Skip to content

swarmrl.exploration_policies.random_exploration Module API Reference

Random exploration module.

RandomExploration

Bases: ExplorationPolicy, ABC

Perform exploration by random moves.

Source code in swarmrl/exploration_policies/random_exploration.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class RandomExploration(ExplorationPolicy, ABC):
    """
    Perform exploration by random moves.

    Each entry of the model's action array is, with probability
    ``probability``, replaced by an action drawn uniformly from
    ``[0, action_space_length)``; otherwise the model's action is kept.
    """

    def __init__(self, probability: float = 0.1):
        """
        Constructor for the random exploration module.

        Parameters
        ----------
        probability : float
                Probability that a random action will be chosen.
                Expected to lie in [0.0, 1.0]; this is not validated.
        """
        # NOTE(review): ``__call__`` is jitted with ``self`` as a static
        # argument, so JAX caches the compiled function keyed on the hash /
        # equality of ``self``. Unless ExplorationPolicy overrides those,
        # mutating ``probability`` after the first call may not be picked up
        # by the cached trace — confirm before relying on mutation.
        self.probability = probability

    @partial(jax.jit, static_argnums=(0,))
    def __call__(
        self, model_actions: np.ndarray, action_space_length: int, seed
    ) -> np.ndarray:
        """
        Return an index associated with the chosen action.

        Parameters
        ----------
        model_actions : np.ndarray (n_colloids,)
                Action chosen by the model for each colloid.
        action_space_length : int
                Number of possible actions, i.e. one more than the highest
                action index: for actions [0, 1, 2, 3] this number should be 4.
        seed : int
                Seed used to build the PRNG key for this call.

        Returns
        -------
        action : np.ndarray
                Action chosen after the exploration module has operated for
                each colloid, cast to int16 and shaped like ``model_actions``.
        """
        # Split before any sampling: a JAX key must not be reused once it
        # has been consumed, so the coin flip and the replacement actions
        # each get their own independent key.
        key = jax.random.PRNGKey(seed)
        flip_key, action_key = jax.random.split(key)

        # One uniform draw per entry; entries below ``probability`` explore.
        sample = jax.random.uniform(flip_key, shape=model_actions.shape)
        explore = sample < self.probability

        # Candidate random actions, drawn with the same shape as the input so
        # the selection below never needs to broadcast mismatched shapes.
        exploration_actions = jax.random.randint(
            action_key,
            shape=model_actions.shape,
            minval=0,
            maxval=action_space_length,
        )

        # Exact boolean selection: unlike a soft arithmetic mask, no
        # fractional blend of two actions can reach the integer cast.
        chosen = jax.numpy.where(explore, exploration_actions, model_actions)
        return chosen.astype(np.int16)

__call__(model_actions, action_space_length, seed)

Return an index associated with the chosen action.

Parameters

model_actions : np.ndarray (n_colloids,) — Action chosen by the model for each colloid. action_space_length : int — Number of possible actions, i.e. one more than the highest action index; e.g., for actions [0, 1, 2, 3] this number should be 4.

Returns

action : np.ndarray — Action chosen for each colloid after the exploration module has been applied.

Source code in swarmrl/exploration_policies/random_exploration.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@partial(jax.jit, static_argnums=(0,))
def __call__(
    self, model_actions: np.ndarray, action_space_length: int, seed
) -> np.ndarray:
    """
    Return an index associated with the chosen action.

    Parameters
    ----------
    model_actions : np.ndarray (n_colloids,)
            Action chosen by the model for each colloid.
    action_space_length : int
            Number of possible actions, i.e. one more than the highest
            action index: for actions [0, 1, 2, 3] this number should be 4.
    seed : int
            Seed used to build the PRNG key for this call.

    Returns
    -------
    action : np.ndarray
            Action chosen after the exploration module has operated for
            each colloid, cast to int16 and shaped like ``model_actions``.
    """
    # Split before any sampling: a JAX key must not be reused once it has
    # been consumed, so the coin flip and the replacement actions each get
    # their own independent key.
    key = jax.random.PRNGKey(seed)
    flip_key, action_key = jax.random.split(key)

    # One uniform draw per entry; entries below ``probability`` explore.
    sample = jax.random.uniform(flip_key, shape=model_actions.shape)
    explore = sample < self.probability

    # Candidate random actions, drawn with the same shape as the input so
    # the selection below never needs to broadcast mismatched shapes.
    exploration_actions = jax.random.randint(
        action_key,
        shape=model_actions.shape,
        minval=0,
        maxval=action_space_length,
    )

    # Exact boolean selection: unlike a soft arithmetic mask, no fractional
    # blend of two actions can reach the integer cast.
    chosen = jax.numpy.where(explore, exploration_actions, model_actions)
    return chosen.astype(np.int16)

__init__(probability=0.1)

Constructor for the random exploration module.

Parameters

probability : float — Probability that a random action will be chosen; expected to lie in [0.0, 1.0].

Source code in swarmrl/exploration_policies/random_exploration.py
19
20
21
22
23
24
25
26
27
28
29
def __init__(self, probability: float = 0.1):
    """
    Build a random-exploration policy.

    Parameters
    ----------
    probability : float
            Chance that the model's action is swapped for a uniformly random
            one. Expected to lie in [0.0, 1.0]; the value is stored as given
            without any range check.
    """
    setattr(self, "probability", probability)