
swarmrl.sampling_strategies.gumbel_distribution Module API Reference

Module for the Gumbel distribution.

GumbelDistribution

Bases: SamplingStrategy, ABC

Sampling strategy that selects a discrete action for each colloid by applying the Gumbel-max trick to model logits.

Source code in swarmrl/sampling_strategies/gumbel_distribution.py
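from abc import ABC

import jax
import jax.numpy as np
import numpy as onp

# SamplingStrategy is swarmrl's sampling base class; its import is not shown in this listing.


class GumbelDistribution(SamplingStrategy, ABC):
    """
    Sampling strategy that selects a discrete action for each colloid by
    applying the Gumbel-max trick to model logits.
    """

    def __call__(self, logits: np.ndarray) -> np.ndarray:
        """
        Sample from the distribution.

        Parameters
        ----------
        logits : np.ndarray (n_colloids, n_dimensions)
                Logits from the model to use in the computation for all colloids.

        Returns
        -------
        indices : np.ndarray (n_colloids,)
                Indices of chosen actions for all colloids.

        Notes
        -----
        See https://arxiv.org/abs/1611.01144 for more information.
        """
        # Build a fresh JAX PRNG key from NumPy's global RNG on every call.
        rng = jax.random.PRNGKey(onp.random.randint(0, 1236534623))
        # One uniform draw per (colloid, action) entry.
        noise = jax.random.uniform(rng, shape=logits.shape)

        # -log(-log(u)) is standard Gumbel noise, so subtracting log(-log(u))
        # perturbs the logits with Gumbel noise; the argmax over the last axis
        # is then a sample from softmax(logits) for each colloid.
        indices = np.argmax(logits - np.log(-np.log(noise)), axis=-1)

        return indices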
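The method above implements the Gumbel-max trick. It adds independent standard Gumbel noise, -log(-log(u)) with u uniform, to every logit and takes the argmax over the action axis. The resulting index is distributed as softmax(logits). Below is a minimal NumPy sketch of the same batched computation, which makes it easy to check shapes and sampling frequencies without JAX. The function names `gumbel_max_sample` and `softmax`, and the use of `numpy.random.Generator` in place of a JAX PRNG key, are hypothetical choices for illustration. None of them is part of swarmrl.

```python
import numpy as np


def gumbel_max_sample(logits: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical NumPy analogue of GumbelDistribution.__call__.

    logits has shape (..., n_actions); returns one action index per leading entry.
    """
    # Draw u in [tiny, 1); keeping u > 0 keeps every log finite.
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=logits.shape)
    # Standard Gumbel noise via the inverse CDF: g = -log(-log(u)).
    gumbel = -np.log(-np.log(u))
    # Argmax over the action axis gives one sample per colloid.
    return np.argmax(logits + gumbel, axis=-1)


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


if __name__ == "__main__":
    rng = np.random.default_rng(seed=0)
    # Two colloids, three actions each.
    logits = np.array([[1.0, 2.0, 0.5], [0.0, 0.0, 3.0]])

    indices = gumbel_max_sample(logits, rng)
    print(indices.shape)  # (2,)

    # Empirical action frequencies approach softmax(logits) as draws increase.
    n_draws = 20_000
    batch = np.repeat(logits[None, :, :], n_draws, axis=0)  # (n_draws, 2, 3)
    samples = gumbel_max_sample(batch, rng)  # (n_draws, 2)
    freqs = np.stack(
        [np.bincount(samples[:, c], minlength=3) / n_draws for c in range(2)]
    )
    print(np.round(freqs, 2))
    print(np.round(softmax(logits), 2))
```

Using a strictly positive lower bound for u is one way to avoid log(0). The swarmrl code instead draws from `jax.random.uniform` with its default range.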

__call__(logits)

Sample from the distribution.

Parameters

logits : np.ndarray (n_colloids, n_dimensions) Logits from the model to use in the computation for all colloids.

Returns

indices : np.ndarray (n_colloids,) Indices of chosen actions for all colloids.

Notes

See https://arxiv.org/abs/1611.01144 for more information.
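The sampling rule rests on the Gumbel-max identity, stated below. The symbols are notation introduced here: x_i is the logit of action i for one colloid, and u_i is the uniform noise entry for that action. The code's expression `logits - log(-log(noise))` equals x_i + g_i, where g_i is standard Gumbel noise. The argmax of these perturbed logits is therefore a categorical sample with softmax probabilities.

```latex
\begin{aligned}
g_i &= -\log\bigl(-\log u_i\bigr), \qquad u_i \overset{\text{i.i.d.}}{\sim} \mathcal{U}(0, 1),\\
x_i - \log\bigl(-\log u_i\bigr) &= x_i + g_i,\\
\Pr\Bigl[\arg\max_i \,(x_i + g_i) = k\Bigr] &= \frac{\exp(x_k)}{\sum_j \exp(x_j)}.
\end{aligned}
```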
