Skip to content

swarmrl.sampling_strategies.categorical_distribution Module API Reference

Module for the categorical distribution.

CategoricalDistribution

Bases: SamplingStrategy, ABC

Sampling strategy that draws option indices from a categorical distribution defined by logits.

Source code in `swarmrl/sampling_strategies/categorical_distribution.py` (lines 14–68):
class CategoricalDistribution(SamplingStrategy, ABC):
    """
    Sampling strategy that draws indices from a categorical distribution.

    Optional noise (uniform or Gaussian) can be added to the logits before
    sampling.
    """

    def __init__(self, noise: str = "none"):
        """
        Constructor for the categorical distribution.

        Parameters
        ----------
        noise : str
                Noise method to use; one of 'none', 'uniform' and 'gaussian'.
                'uniform' adds jax.random.uniform noise (in [0, 1)) and
                'gaussian' adds standard-normal jax.random.normal noise to
                the logits before sampling; 'none' samples the raw logits.

        Raises
        ------
        KeyError
                If ``noise`` is not one of the supported options.
        """
        noise_dict = {
            "uniform": jax.random.uniform,
            "gaussian": jax.random.normal,
            "none": None,
        }
        try:
            self.noise = noise_dict[noise]
        except KeyError:
            msg = (
                f"Parsed noise method {noise} is not implemented, please choose "
                "from 'none', 'gaussian' and 'uniform'."
            )
            raise KeyError(msg) from None

    def __call__(self, logits: np.ndarray) -> np.ndarray:
        """
        Sample from the distribution.

        Parameters
        ----------
        logits : np.ndarray (n_colloids, n_dimensions)
                Logits from the model to use in the computation for all colloids.
                One index is sampled along the last axis for every row.

        Returns
        -------
        indices : np.ndarray (n_colloids,)
                Index of the selected option in the distribution.
        """
        # Seed from NumPy's global RNG so runs can be reproduced via
        # onp.random.seed.
        rng = jax.random.PRNGKey(onp.random.randint(0, 1236534623))
        # Independent subkeys: reusing one key for both the noise and the
        # categorical draw would correlate the two random streams.
        noise_key, sample_key = jax.random.split(rng)

        # Explicit None check instead of catching TypeError, so genuine
        # errors raised by the noise function are not silently swallowed.
        if self.noise is None:
            noise = 0
        else:
            noise = self.noise(noise_key, shape=logits.shape)

        indices = jax.random.categorical(sample_key, logits=logits + noise)

        return indices

__call__(logits)

Sample from the distribution.

Parameters

logits : np.ndarray (n_colloids, n_dimensions) — Logits from the model to use in the computation for all colloids.

Returns

indices : np.ndarray (n_colloids,) — Index of the selected option in the distribution.

Source code in `swarmrl/sampling_strategies/categorical_distribution.py` (lines 42–68):
def __call__(self, logits: np.ndarray) -> np.ndarray:
    """
    Sample from the distribution.

    Parameters
    ----------
    logits : np.ndarray (n_colloids, n_dimensions)
            Logits from the model to use in the computation for all colloids.
            One index is sampled along the last axis for every row.

    Returns
    -------
    indices : np.ndarray (n_colloids,)
            Index of the selected option in the distribution.
    """
    # Seed from NumPy's global RNG so runs can be reproduced via
    # onp.random.seed.
    rng = jax.random.PRNGKey(onp.random.randint(0, 1236534623))
    # Independent subkeys: reusing one key for both the noise and the
    # categorical draw would correlate the two random streams.
    noise_key, sample_key = jax.random.split(rng)

    # Explicit None check instead of catching TypeError, so genuine errors
    # raised by the noise function are not silently swallowed.
    if self.noise is None:
        noise = 0
    else:
        noise = self.noise(noise_key, shape=logits.shape)

    indices = jax.random.categorical(sample_key, logits=logits + noise)

    return indices

__init__(noise='none')

Constructor for the categorical distribution.

Parameters

noise : str — Noise method to use; one of `'none'`, `'uniform'` or `'gaussian'`.

Source code in `swarmrl/sampling_strategies/categorical_distribution.py` (lines 19–40):
def __init__(self, noise: str = "none"):
    """
    Constructor for the categorical distribution.

    Parameters
    ----------
    noise : str
            Noise method to use; one of 'none', 'uniform' and 'gaussian'.
            'uniform' selects jax.random.uniform (values in [0, 1)) and
            'gaussian' selects jax.random.normal (standard normal) as the
            noise function; 'none' stores None so no noise is added.

    Raises
    ------
    KeyError
            If ``noise`` is not one of the supported options.
    """
    noise_dict = {
        "uniform": jax.random.uniform,
        "gaussian": jax.random.normal,
        "none": None,
    }
    try:
        self.noise = noise_dict[noise]
    except KeyError:
        msg = (
            f"Parsed noise method {noise} is not implemented, please choose "
            "from 'none', 'gaussian' and 'uniform'."
        )
        raise KeyError(msg) from None