
swarmrl.losses.policy_gradient_loss Module API Reference

Module for the implementation of policy gradient loss.

Policy gradient is the simplest loss function: the policy is updated with advantage-weighted log-probabilities, while a critic loss fits the predicted values to the expected returns.

Notes

https://spinningup.openai.com/en/latest/algorithms/vpg.html
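
The quantity minimised by this module is the vanilla policy gradient objective with a value baseline. As a point of reference, a minimal JAX sketch of that objective (illustrative only, not the library code; the names are for exposition):

import jax.numpy as jnp

def vpg_loss(log_probs, returns, predicted_values):
    # All inputs have shape (n_time_steps, n_particles).
    advantage = returns - predicted_values
    # Maximise advantage-weighted log-probabilities, i.e. minimise their negative sum.
    actor_loss = -(log_probs * advantage).sum()
    # Fit the predicted values to the returns (squared error for brevity;
    # the listing below uses a Huber loss for the critic).
    critic_loss = ((predicted_values - returns) ** 2).sum()
    return actor_loss + critic_loss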

PolicyGradientLoss

Bases: Loss

Policy gradient loss for training a shared actor-critic network.

Notes

Source code in swarmrl/losses/policy_gradient_loss.py
class PolicyGradientLoss(Loss):
    """
    Policy gradient loss for training a shared actor-critic network.

    Notes
    -----
    """

    def __init__(self, value_function: ExpectedReturns = ExpectedReturns()):
        """
        Constructor for the policy gradient loss.

        Parameters
        ----------
        value_function : ExpectedReturns
            Value function used to compute the returns from the episode rewards.
        """
        super(Loss, self).__init__()
        self.value_function = value_function
        self.n_particles = None
        self.n_time_steps = None

    def _calculate_loss(
        self,
        network_params: FrozenDict,
        network: Network,
        feature_data: jnp.ndarray,
        action_indices: jnp.ndarray,
        rewards: jnp.ndarray,
    ) -> jnp.array:
        """
        Compute the loss of the shared actor-critic network.

        Parameters
        ----------
        network : FlaxModel
            The actor-critic network that approximates the policy.
        network_params : FrozenDict
            Parameters of the actor-critic model used.
        feature_data : np.ndarray (n_time_steps, n_particles, feature_dimension)
            Observable data for each time step and particle within the episode.
        action_indices : np.ndarray (n_time_steps, n_particles)
            The actions taken by the policy for all time steps and particles during one
            episode.
        rewards : np.ndarray (n_time_steps, n_particles)
            The rewards received for all time steps and particles during one episode.


        Returns
        -------
        loss : float
            The loss of the actor-critic network for the last episode.
        """

        # (n_timesteps, n_particles, n_possibilities)
        logits, predicted_values = network(network_params, feature_data)
        predicted_values = predicted_values.squeeze()
        probabilities = jax.nn.softmax(logits)  # get probabilities
        chosen_probabilities = gather_n_dim_indices(probabilities, action_indices)
        log_probs = jnp.log(chosen_probabilities + 1e-8)
        logger.debug(f"{log_probs.shape=}")

        returns = self.value_function(rewards)
        logger.debug(f"{returns.shape}")

        logger.debug(f"{predicted_values.shape=}")

        # (n_timesteps, n_particles)
        advantage = returns - predicted_values
        logger.debug(f"{advantage=}")

        actor_loss = -1 * ((log_probs * advantage).sum(axis=0)).sum()
        logger.debug(f"{actor_loss=}")

        # Sum over time steps and over agents.
        critic_loss = optax.huber_loss(predicted_values, returns).sum(axis=0).sum()

        return actor_loss + critic_loss

    def compute_loss(self, network: Network, episode_data):
        """
        Compute the loss and update the shared actor-critic network.

        Parameters
        ----------
        network : Network
                actor-critic model to use in the analysis.
        episode_data : np.ndarray (n_timesteps, n_particles, feature_dimension)
                Observable data for each time step and particle within the episode.

        Returns
        -------

        """
        feature_data = jnp.array(episode_data.features)
        action_data = jnp.array(episode_data.actions)
        reward_data = jnp.array(episode_data.rewards)

        self.n_particles = jnp.shape(feature_data)[1]
        self.n_time_steps = jnp.shape(feature_data)[0]

        network_grad_fn = jax.value_and_grad(self._calculate_loss)
        _, network_grads = network_grad_fn(
            network.model_state.params,
            network=network,
            feature_data=feature_data,
            action_indices=action_data,
            rewards=reward_data,
        )

        network.update_model(network_grads)
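
In the listing above, gather_n_dim_indices is assumed to select, for every time step and particle, the probability of the action that was actually taken. A functionally equivalent selection can be sketched with jnp.take_along_axis (an illustration, not the library helper):

import jax.numpy as jnp

def select_chosen_probabilities(probabilities, action_indices):
    # probabilities: (n_time_steps, n_particles, n_possibilities)
    # action_indices: (n_time_steps, n_particles), integer action ids
    chosen = jnp.take_along_axis(probabilities, action_indices[..., None], axis=-1)
    return chosen.squeeze(-1)  # (n_time_steps, n_particles)

Taking the log of these probabilities (with the small constant added in the listing for numerical safety) gives the log-likelihood of the executed actions, which is then weighted by the advantage.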

__init__(value_function=ExpectedReturns())

Constructor for the policy gradient loss.

Parameters

value_function : ExpectedReturns
    Value function used to compute the returns from the episode rewards.

Source code in swarmrl/losses/policy_gradient_loss.py
def __init__(self, value_function: ExpectedReturns = ExpectedReturns()):
    """
    Constructor for the policy gradient loss.

    Parameters
    ----------
    value_function : ExpectedReturns
        Value function used to compute the returns from the episode rewards.
    """
    super(Loss, self).__init__()
    self.value_function = value_function
    self.n_particles = None
    self.n_time_steps = None
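
The value_function argument only has to be a callable that maps the reward array of shape (n_time_steps, n_particles) to a returns array of the same shape; ExpectedReturns is the library default. To illustrate that interface, a hypothetical drop-in replacement computing discounted reward-to-go might look as follows (a sketch, not part of SwarmRL):

import jax.numpy as jnp

class SimpleDiscountedReturns:
    """Hypothetical value function: discounted reward-to-go per particle."""

    def __init__(self, gamma: float = 0.99):
        self.gamma = gamma

    def __call__(self, rewards):
        # rewards: (n_time_steps, n_particles)
        returns = jnp.zeros_like(rewards)
        running = jnp.zeros(rewards.shape[1])
        # Accumulate discounted rewards from the last time step backwards.
        for t in reversed(range(rewards.shape[0])):
            running = rewards[t] + self.gamma * running
            returns = returns.at[t].set(running)
        return returns

Such a callable would be passed as PolicyGradientLoss(value_function=SimpleDiscountedReturns(gamma=0.9)).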

compute_loss(network, episode_data)

Compute the loss and update the shared actor-critic network.

Parameters

network : Network
    Actor-critic model to use in the analysis.
episode_data : np.ndarray (n_timesteps, n_particles, feature_dimension)
    Observable data for each time step and particle within the episode.

Returns

None. The shared actor-critic network is updated in place.

Source code in swarmrl/losses/policy_gradient_loss.py
def compute_loss(self, network: Network, episode_data):
    """
    Compute the loss and update the shared actor-critic network.

    Parameters
    ----------
    network : Network
            actor-critic model to use in the analysis.
    episode_data : np.ndarray (n_timesteps, n_particles, feature_dimension)
            Observable data for each time step and particle within the episode.

    Returns
    -------

    """
    feature_data = jnp.array(episode_data.features)
    action_data = jnp.array(episode_data.actions)
    reward_data = jnp.array(episode_data.rewards)

    self.n_particles = jnp.shape(feature_data)[1]
    self.n_time_steps = jnp.shape(feature_data)[0]

    network_grad_fn = jax.value_and_grad(self._calculate_loss)
    _, network_grads = network_grad_fn(
        network.model_state.params,
        network=network,
        feature_data=feature_data,
        action_indices=action_data,
        rewards=reward_data,
    )

    network.update_model(network_grads)
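
A minimal usage sketch, assuming an episode container that exposes the features, actions and rewards attributes with the shapes documented above, and an actor-critic network providing model_state.params, update_model and a call signature network(params, features) returning (logits, predicted_values), as used in the listing. The network construction is outside the scope of this page, so the call is left as a comment:

import numpy as np
from types import SimpleNamespace

from swarmrl.losses.policy_gradient_loss import PolicyGradientLoss

n_time_steps, n_particles, feature_dimension, n_actions = 20, 5, 3, 4

# Hypothetical episode record with the shapes documented above.
episode_data = SimpleNamespace(
    features=np.random.rand(n_time_steps, n_particles, feature_dimension),
    actions=np.random.randint(0, n_actions, size=(n_time_steps, n_particles)),
    rewards=np.random.rand(n_time_steps, n_particles),
)

loss = PolicyGradientLoss()  # ExpectedReturns() is used as the default value function

# With an actor-critic `network` built elsewhere (e.g. one of SwarmRL's Flax models):
# loss.compute_loss(network=network, episode_data=episode_data)

Note that jax.value_and_grad differentiates only with respect to the first positional argument, the network parameters; the network object and the episode arrays are passed as keyword arguments and are therefore treated as constants during differentiation.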