Skip to content

swarmrl.networks.network Module API Reference

Parent class for the networks.

Network

A parent class for the networks that will be used.

Source code in swarmrl/networks/network.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
class Network:
    """
    A parent class for the networks that will be used.

    Every method defined here raises ``NotImplementedError``; concrete
    networks are expected to subclass ``Network`` and override them.
    """

    def compute_action(self, observables: List[Colloid], explore_mode: bool = False):
        """
        Compute an action from the action space.

        This method computes an action on all colloids of the relevant type.

        Parameters
        ----------
        observables : List[Colloid]
                Colloids in the system for which the action should be computed.
        explore_mode : bool
                If true, an exploration vs exploitation function is called.

        Returns
        -------
        action : int
                An integer bounded between 0 and the number of output neurons
                corresponding to the action chosen by the agent.

        Raises
        ------
        NotImplementedError
                Always; child classes must override this method.
        """
        raise NotImplementedError("Implemented in child class.")

    def __call__(self, params: FrozenDict, feature_vector: np.ndarray):
        """
        Perform the forward pass on the model. This method is
        used in the update. It uses a vmapped version of the
        model.apply function.

        Parameters
        ----------
        params : FrozenDict
                Parameters of the model.
        feature_vector : np.ndarray
                Current state of the agent on which actions should be made.

        Returns
        -------
        The output of the model's forward pass. Its type and shape are
        defined by the child class implementation.

        Raises
        ------
        NotImplementedError
                Always; child classes must override this method.
        """
        raise NotImplementedError("Implemented in child class.")

    def export_model(self, filename: str = "models", directory: str = "Models"):
        """
        Export the model state to a directory.

        Parameters
        ----------
        filename : str (default=models)
                Name of the file the models are saved in.
        directory : str (default=Models)
                Directory in which to save the models. If the directory is not
                in the current directory, it will be created.

        Raises
        ------
        NotImplementedError
                Always; child classes must override this method.
        """
        raise NotImplementedError("Implemented in child class")

    def restore_model_state(self, filename: str, directory: str):
        """
        Restore the model state from a file.

        Implementations update the model state from the stored file rather
        than returning a new model.

        Parameters
        ----------
        filename : str
                Name of the model state file.
        directory : str
                Path to the model state file.

        Raises
        ------
        NotImplementedError
                Always; child classes must override this method.
        """
        raise NotImplementedError("Implemented in child class")

    def update_model(
        self,
        grads,
    ):
        """
        Train the model.

        For JAX models, the grads are used to update a model state directly.
        This method takes the grads and updates the params dict corresponding
        to the relevant model.

        Parameters
        ----------
        grads : dict
                Dict of grads from a jax value_and_grad call.

        Raises
        ------
        NotImplementedError
                Always; child classes must override this method.
        """
        raise NotImplementedError("Implemented in child class.")

__call__(params, feature_vector)

Perform the forward pass on the model. This method is used in the update. It uses a vmapped version of the model.apply function.

Parameters

params : FrozenDict Parameters of the model. feature_vector : np.ndarray Current state of the agent on which actions should be made.

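Returns

The output of the model's forward pass; its type and shape are defined by the child class implementation.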
Source code in swarmrl/networks/network.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def __call__(self, params: FrozenDict, feature_vector: np.ndarray):
    """
    Perform the forward pass on the model. This method is
    used in the update. It uses a vmapped version of the
    model.apply function.

    Parameters
    ----------
    params : FrozenDict
            Parameters of the model.
    feature_vector : np.ndarray
            Current state of the agent on which actions should be made.

    Returns
    -------
    The output of the model's forward pass. Its type and shape are
    defined by the child class implementation.

    Raises
    ------
    NotImplementedError
            Always; child classes must override this method.
    """
    raise NotImplementedError("Implemented in child class.")

compute_action(observables, explore_mode=False)

Compute an action from the action space.

This method computes an action on all colloids of the relevant type.

Parameters

observables : List[Colloid] Colloids in the system for which the action should be computed. explore_mode : bool If true, an exploration vs exploitation function is called.

Returns

action : int An integer bounded between 0 and the number of output neurons corresponding to the action chosen by the agent.

Source code in swarmrl/networks/network.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def compute_action(self, observables: List[Colloid], explore_mode: bool = False):
    """
    Compute an action from the action space.

    This method computes an action on all colloids of the relevant type.

    Parameters
    ----------
    observables : List[Colloid]
            Colloids in the system for which the action should be computed.
    explore_mode : bool
            If true, an exploration vs exploitation function is called.

    Returns
    -------
    action : int
            An integer bounded between 0 and the number of output neurons
            corresponding to the action chosen by the agent.

    Raises
    ------
    NotImplementedError
            Always; child classes must override this method.
    """
    raise NotImplementedError("Implemented in child class.")

export_model(filename='models', directory='Models')

Export the model state to a directory.

Parameters

filename : str (default=models) Name of the file the models are saved in. directory : str (default=Models) Directory in which to save the models. If the directory is not in the current directory, it will be created.

Source code in swarmrl/networks/network.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def export_model(self, filename: str = "models", directory: str = "Models"):
    """
    Export the model state to a directory.

    Parameters
    ----------
    filename : str (default=models)
            Name of the file the models are saved in.
    directory : str (default=Models)
            Directory in which to save the models. If the directory is not
            in the current directory, it will be created.

    Raises
    ------
    NotImplementedError
            Always; child classes must override this method.
    """
    raise NotImplementedError("Implemented in child class")

restore_model_state(filename, directory)

Restore the model state from a file.

Parameters

filename : str Name of the model state file. directory : str Path to the model state file.

Returns

Updates the model state.

Source code in swarmrl/networks/network.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def restore_model_state(self, filename: str, directory: str):
    """
    Restore the model state from a file.

    Implementations update the model state from the stored file rather
    than returning a new model.

    Parameters
    ----------
    filename : str
            Name of the model state file.
    directory : str
            Path to the model state file.

    Raises
    ------
    NotImplementedError
            Always; child classes must override this method.
    """
    raise NotImplementedError("Implemented in child class")

update_model(grads)

Train the model.

For JAX models, the grads are used to update a model state directly. This method takes the grads and updates the params dict corresponding to the relevant model.

Parameters

grads : dict Dict of grads from a jax value_and_grad call.

Source code in swarmrl/networks/network.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def update_model(self, grads):
    """
    Apply gradients to the model (training step).

    For JAX models the gradients are applied directly to a model state:
    an implementation takes ``grads`` and updates the params dict of the
    corresponding model. This base version provides no implementation.

    Parameters
    ----------
    grads : dict
            Gradients as returned by a jax value_and_grad call.

    Raises
    ------
    NotImplementedError
            Always; child classes must override this method.
    """
    message = "Implemented in child class."
    raise NotImplementedError(message)