
swarmrl.losses.proximal_policy_loss: Module API Reference

Loss functions based on Proximal Policy Optimization (PPO).

Notes

https://spinningup.openai.com/en/latest/algorithms/ppo.html
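
For reference, the total loss assembled in _calculate_loss below combines the clipped surrogate objective described in the link above with an entropy bonus and a Huber value loss. In the notation of the Spinning Up documentation it reads, roughly,

L(\theta) = -\sum_{t,p} \min\Big( r_{t,p}(\theta)\,\hat{A}_{t,p},\; \mathrm{clip}\big(r_{t,p}(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_{t,p} \Big) \;-\; c_{\mathrm{ent}}\, H[\pi_\theta] \;+\; \tfrac{1}{2} \sum_{t,p} \ell_{\mathrm{Huber}}\big( V_\theta(s_{t,p}),\, \hat{R}_{t,p} \big),

where r_{t,p}(\theta) = \pi_\theta(a_{t,p} \mid s_{t,p}) / \pi_{\theta_{\mathrm{old}}}(a_{t,p} \mid s_{t,p}) is the probability ratio over time steps t and particles p, \hat{A}_{t,p} and \hat{R}_{t,p} are the advantages and returns produced by the value function (GAE by default), H is the entropy of the policy distribution, and \epsilon and c_{\mathrm{ent}} correspond to the epsilon and entropy_coefficient parameters.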

ProximalPolicyLoss

Bases: Loss, ABC

Class to implement the proximal policy loss.

Source code in swarmrl/losses/proximal_policy_loss.py
class ProximalPolicyLoss(Loss, ABC):
    """
    Class to implement the proximal policy loss.
    """

    def __init__(
        self,
        value_function: GAE = GAE(),
        sampling_strategy: SamplingStrategy = GumbelDistribution(),
        n_epochs: int = 20,
        epsilon: float = 0.2,
        entropy_coefficient: float = 0.01,
    ):
        """
        Constructor for the PPO class.

        Parameters
        ----------
        value_function : Callable
            The state value function that computes the value of a series of states
            using the rewards of the trajectory visiting these states.
        sampling_strategy : SamplingStrategy
            Sampling strategy used to compute the entropy of the policy
            distribution during the PPO update.
        n_epochs : int
            Number of PPO updates performed per call to compute_loss.
        epsilon : float
            The maximum relative distance between the old and the updated policy
            (clipping parameter of the surrogate objective).
        entropy_coefficient : float
            Entropy coefficient weighting the entropy bonus in the PPO update.

        """
        self.value_function = value_function
        self.sampling_strategy = sampling_strategy
        self.n_epochs = n_epochs
        self.epsilon = epsilon
        self.entropy_coefficient = entropy_coefficient
        self.eps = 1e-8

    @partial(jit, static_argnums=(0, 2))
    def _calculate_loss(
        self,
        network_params: FrozenDict,
        network: Network,
        feature_data,
        action_indices,
        rewards,
        old_log_probs,
    ) -> jnp.array:
        """
        A function that computes the actor loss.

        Parameters
        ----------
        network : FlaxModel
            The actor-critic network that approximates the policy.
        network_params : FrozenDict
            Parameters of the actor-critic model used.
        feature_data : np.ndarray (n_time_steps, n_particles, feature_dimension)
            Observable data for each time step and particle within the episode.
        action_indices : np.ndarray (n_time_steps, n_particles)
            The actions taken by the policy for all time steps and particles during one
            episode.
        rewards : np.ndarray (n_time_steps, n_particles)
            The rewards received for all time steps and particles during one episode.
        old_log_probs : np.ndarray (n_time_steps, n_particles)
            The log probabilities of the actions taken by the policy for all time steps
            and particles during one episode.

        Returns
        -------
        loss: float
            The loss of the actor-critic network for the last episode.
        """

        # forward pass: new policy logits and predicted state values
        new_logits, predicted_values = network(network_params, feature_data)
        predicted_values = predicted_values.squeeze()

        # compute the advantages and returns
        advantages, returns = self.value_function(
            rewards=rewards, values=predicted_values
        )

        # compute the probabilities of the old actions under the new policy
        new_probabilities = jax.nn.softmax(new_logits, axis=-1)

        # compute the entropy of the whole distribution
        entropy = self.sampling_strategy.compute_entropy(new_probabilities).sum()
        chosen_log_probs = jnp.log(
            gather_n_dim_indices(new_probabilities, action_indices) + self.eps
        )

        # compute the ratio between old and new probs
        ratio = jnp.exp(chosen_log_probs - old_log_probs)

        # Compute the clipped actor loss
        clipped_loss = -1 * jnp.minimum(
            ratio * advantages,
            jnp.clip(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages,
        )
        particle_actor_loss = jnp.sum(clipped_loss, axis=0)
        actor_loss = jnp.sum(particle_actor_loss)

        # Compute critic loss
        total_critic_loss = (
            optax.huber_loss(predicted_values, returns).sum(axis=0).sum()
        )

        # Compute combined loss
        loss = actor_loss - self.entropy_coefficient * entropy + 0.5 * total_critic_loss

        return loss

    def compute_loss(self, network: Network, episode_data):
        """
        Compute the loss and update the shared actor-critic network.

        Parameters
        ----------
        network : Network
            Actor-critic model to use in the analysis.
        episode_data : np.ndarray (n_timesteps, n_particles, feature_dimension)
            Observable data for each time step and particle within the episode.

        Returns
        -------
        None
            The shared actor-critic network is updated in place.
        """
        old_log_probs_data = jnp.array(episode_data.log_probs)
        feature_data = jnp.array(episode_data.features)
        action_data = jnp.array(episode_data.actions)
        reward_data = jnp.array(episode_data.rewards)

        for _ in range(self.n_epochs):
            network_grad_fn = jax.value_and_grad(self._calculate_loss)
            _, network_grad = network_grad_fn(
                network.model_state.params,
                network=network,
                feature_data=feature_data,
                action_indices=action_data,
                rewards=reward_data,
                old_log_probs=old_log_probs_data,
            )

            network.update_model(network_grad)

__init__(value_function=GAE(), sampling_strategy=GumbelDistribution(), n_epochs=20, epsilon=0.2, entropy_coefficient=0.01)

Constructor for the PPO class.

Parameters

value_function : Callable
    The state value function that computes the value of a series of states using the rewards of the trajectory visiting these states.
sampling_strategy : SamplingStrategy
    Sampling strategy used to compute the entropy of the policy distribution during the PPO update.
n_epochs : int
    Number of PPO updates performed per call to compute_loss.
epsilon : float
    The maximum relative distance between the old and the updated policy (clipping parameter of the surrogate objective).
entropy_coefficient : float
    Entropy coefficient weighting the entropy bonus in the PPO update.
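
As a usage illustration (a minimal sketch, not taken from the SwarmRL documentation), the constructor can be called with only the scalar hyperparameters spelled out; the value function and sampling strategy fall back to GAE() and GumbelDistribution() when omitted. The values shown here are simply the documented defaults.

from swarmrl.losses.proximal_policy_loss import ProximalPolicyLoss

# All arguments are optional; only the scalar hyperparameters are set explicitly.
ppo_loss = ProximalPolicyLoss(
    n_epochs=20,               # number of PPO epochs per call to compute_loss
    epsilon=0.2,               # clipping range for the probability ratio
    entropy_coefficient=0.01,  # weight of the entropy bonus in the total loss
)

The resulting object is then used once per episode through compute_loss, which updates the shared actor-critic network in place.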

Source code in swarmrl/losses/proximal_policy_loss.py
def __init__(
    self,
    value_function: GAE = GAE(),
    sampling_strategy: SamplingStrategy = GumbelDistribution(),
    n_epochs: int = 20,
    epsilon: float = 0.2,
    entropy_coefficient: float = 0.01,
):
    """
    Constructor for the PPO class.

    Parameters
    ----------
    value_function : Callable
        The state value function that computes the value of a series of states
        using the rewards of the trajectory visiting these states.
    sampling_strategy : SamplingStrategy
        Sampling strategy used to compute the entropy of the policy
        distribution during the PPO update.
    n_epochs : int
        Number of PPO updates performed per call to compute_loss.
    epsilon : float
        The maximum relative distance between the old and the updated policy
        (clipping parameter of the surrogate objective).
    entropy_coefficient : float
        Entropy coefficient weighting the entropy bonus in the PPO update.

    """
    self.value_function = value_function
    self.sampling_strategy = sampling_strategy
    self.n_epochs = n_epochs
    self.epsilon = epsilon
    self.entropy_coefficient = entropy_coefficient
    self.eps = 1e-8

compute_loss(network, episode_data)

Compute the loss and update the shared actor-critic network.

Parameters

network : Network
    Actor-critic model to use in the analysis.
episode_data : np.ndarray (n_timesteps, n_particles, feature_dimension)
    Observable data for each time step and particle within the episode.

Returns

None. The shared actor-critic network is updated in place.

Source code in swarmrl/losses/proximal_policy_loss.py
def compute_loss(self, network: Network, episode_data):
    """
    Compute the loss and update the shared actor-critic network.

    Parameters
    ----------
    network : Network
        Actor-critic model to use in the analysis.
    episode_data : np.ndarray (n_timesteps, n_particles, feature_dimension)
        Observable data for each time step and particle within the episode.

    Returns
    -------
    None
        The shared actor-critic network is updated in place.
    """
    old_log_probs_data = jnp.array(episode_data.log_probs)
    feature_data = jnp.array(episode_data.features)
    action_data = jnp.array(episode_data.actions)
    reward_data = jnp.array(episode_data.rewards)

    for _ in range(self.n_epochs):
        network_grad_fn = jax.value_and_grad(self._calculate_loss)
        _, network_grad = network_grad_fn(
            network.model_state.params,
            network=network,
            feature_data=feature_data,
            action_indices=action_data,
            rewards=reward_data,
            old_log_probs=old_log_probs_data,
        )

        network.update_model(network_grad)
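
For clarity, compute_loss only reads four array-like attributes from episode_data: features, actions, rewards and log_probs (see the body above). A minimal stand-in container, assuming the shapes documented in _calculate_loss, could therefore look like the sketch below; in practice SwarmRL supplies its own episode data structure, and the names actor_critic_network and episode_data in the trailing comment are hypothetical placeholders, so this is purely illustrative.

from dataclasses import dataclass

import numpy as np


@dataclass
class EpisodeDataSketch:
    """Illustrative container exposing the attributes read by compute_loss."""

    features: np.ndarray   # (n_time_steps, n_particles, feature_dimension)
    actions: np.ndarray    # (n_time_steps, n_particles), indices of the chosen actions
    rewards: np.ndarray    # (n_time_steps, n_particles)
    log_probs: np.ndarray  # (n_time_steps, n_particles), log-probabilities of the chosen actions


# With a SwarmRL actor-critic Network instance (here a hypothetical placeholder
# called actor_critic_network), one PPO update per episode would then be:
#
#     ppo_loss = ProximalPolicyLoss()
#     ppo_loss.compute_loss(network=actor_critic_network, episode_data=episode_data)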