swarmrl.networks.network Module API Reference
Parent class for the networks.
Network
A parent class for the networks that will be used.
Source code in swarmrl/networks/network.py
__call__(params, feature_vector)
Perform the forward pass of the model. This method is used during the update and relies on a vmapped version of the model.apply function.
Parameters
params : FrozenDict
    Parameters of the model.
feature_vector : np.ndarray
    Current state of the agent on which actions should be made.
Returns
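The vmapped forward pass evaluates a single-agent model on every row of feature_vector while sharing one set of params. The following is a minimal NumPy sketch of those semantics, not the SwarmRL implementation: the one-layer apply_fn and the plain dict of arrays are hypothetical stand-ins for a Flax model.apply and its FrozenDict, and the stacked list comprehension plays the role that jax.vmap(model.apply, in_axes=(None, 0)) would play in JAX.

```python
import numpy as np

def apply_fn(params, x):
    """Hypothetical single-agent model: one dense layer on x of shape (n_features,)."""
    return params["w"] @ x + params["b"]

def vmapped_forward(params, feature_vector):
    """Apply apply_fn to each row (one row per agent) with shared params.

    Same semantics as jax.vmap(apply_fn, in_axes=(None, 0)): params are
    broadcast, feature_vector is mapped over its leading axis.
    """
    return np.stack([apply_fn(params, x) for x in feature_vector])

rng = np.random.default_rng(0)
params = {"w": rng.normal(size=(4, 3)), "b": np.zeros(4)}  # 3 features -> 4 outputs
features = rng.normal(size=(10, 3))  # 10 agents, 3 features each
logits = vmapped_forward(params, features)
print(logits.shape)  # (10, 4)
```

In JAX the explicit loop disappears: vmap traces apply_fn once and produces a batched function, which is why the per-agent model can be written without any batch dimension.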
compute_action(observables, explore_mode=False)
Compute an action from the action space.
This method computes an action for every colloid of the relevant type.
Parameters
observables : List[Colloid]
    Colloids in the system for which the action should be computed.
explore_mode : bool
    If True, an exploration vs. exploitation function is called.
Returns
action : int
    An integer bounded between 0 and the number of output neurons, corresponding to the action chosen by the agent.
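This page does not say which exploration rule compute_action uses, so the sketch below assumes epsilon-greedy selection on top of a greedy argmax. select_actions and epsilon are hypothetical names; the function takes a (n_colloids, n_actions) array of model outputs and returns one integer per colloid in the range [0, n_actions), consistent with the bound described under Returns.

```python
import numpy as np

def select_actions(logits, explore_mode=False, epsilon=0.1, rng=None):
    """Return one action index per colloid from a (n_colloids, n_actions) array.

    Assumption: greedy argmax by default; with explore_mode=True each colloid
    instead takes a uniformly random action with probability epsilon
    (epsilon-greedy). SwarmRL's actual exploration function may differ.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_colloids, n_actions = logits.shape
    actions = np.argmax(logits, axis=1)  # exploitation: best-scoring output neuron
    if explore_mode:
        explore = rng.random(n_colloids) < epsilon  # which colloids explore this step
        random_actions = rng.integers(0, n_actions, size=n_colloids)  # upper bound exclusive
        actions = np.where(explore, random_actions, actions)
    return actions

logits = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.0, 0.5]])
print(select_actions(logits))  # [1 0]
```

Passing a seeded np.random.default_rng makes exploratory runs reproducible.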
export_model(filename='models', directory='Models')
Export the model state to a directory.
Parameters
filename : str (default='models')
    Name of the file the models are saved in.
directory : str (default='Models')
    Directory in which to save the models. If the directory does not exist in the current directory, it will be created.
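Exporting amounts to making sure the target directory exists and serializing the state into it. The sketch below assumes pickle serialization and a .pkl suffix, neither of which this page specifies; export_state is a hypothetical helper that mirrors the filename and directory defaults of export_model.

```python
import os
import pickle
import tempfile

def export_state(state, filename="models", directory="Models"):
    """Pickle `state` to <directory>/<filename>.pkl and return the path.

    Assumptions: pickle format and .pkl suffix. The directory (and any
    missing parents) is created if it does not already exist.
    """
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"{filename}.pkl")
    with open(path, "wb") as f:
        pickle.dump(state, f)
    return path

# Write into a throwaway location so the example leaves no files behind.
state = {"params": {"w": [[0.5, -0.5]], "b": [0.0]}, "step": 3}
with tempfile.TemporaryDirectory() as tmp:
    path = export_state(state, directory=os.path.join(tmp, "Models"))
    print(os.path.isfile(path))  # True
```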
restore_model_state(filename, directory)
Restore the model state from a file.
Parameters
filename : str
    Name of the model state file.
directory : str
    Path to the directory containing the model state file.
Returns
Updates the model state.
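Restoring reverses the export: build the path from directory and filename, then deserialize. The pickle format, the .pkl suffix and the restore_state helper are assumptions made for this sketch; the documented method updates the model state in place, whereas this standalone function returns the loaded value so that a caller can assign it.

```python
import os
import pickle
import tempfile

def restore_state(filename, directory):
    """Load a pickled model state from <directory>/<filename>.pkl (assumed layout)."""
    path = os.path.join(directory, f"{filename}.pkl")
    with open(path, "rb") as f:
        return pickle.load(f)

# Round trip: write a state file by hand, then restore it.
saved = {"params": {"w": [[0.5, -0.5]], "b": [0.0]}, "step": 3}
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "models.pkl"), "wb") as f:
        pickle.dump(saved, f)
    restored = restore_state("models", tmp)
print(restored == saved)  # True
```

Only unpickle files from trusted sources: pickle.load can execute arbitrary code embedded in a crafted file.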
update_model(grads)
Train the model.
For JAX models, the gradients are used to update the model state directly: this method takes the grads and updates the params dict of the relevant model.
Parameters
grads : dict
    Dict of gradients from a jax.value_and_grad call.
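update_model consumes a grads dict produced by jax.value_and_grad and applies it to the params dict with the same structure. The sketch below reproduces that data flow in NumPy under stated assumptions: loss_and_grads is a hypothetical stand-in that returns analytic gradients nested exactly like params (as value_and_grad would), and sgd_update applies a plain gradient-descent step. The actual class may instead route the grads through an optimizer such as Optax, which this page does not specify.

```python
import numpy as np

def loss_and_grads(params, x, y):
    """MSE of a linear model and its analytic gradients.

    Hypothetical stand-in for jax.value_and_grad(loss_fn)(params, x, y):
    the returned grads dict has exactly the same nested structure as params.
    """
    layer = params["dense"]
    err = x @ layer["w"] + layer["b"] - y
    loss = np.mean(err ** 2)
    grads = {"dense": {"w": 2.0 * x.T @ err / len(y), "b": 2.0 * np.mean(err)}}
    return loss, grads

def sgd_update(params, grads, learning_rate=0.1):
    """Assumption: plain SGD. Walk the nested dict and apply p <- p - lr * g."""
    if isinstance(params, dict):
        return {k: sgd_update(params[k], grads[k], learning_rate) for k in params}
    return params - learning_rate * grads

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 2))
y = x @ np.array([1.5, -2.0]) + 0.5  # targets generated by a known linear rule
params = {"dense": {"w": np.zeros(2), "b": 0.0}}

initial_loss, _ = loss_and_grads(params, x, y)
for _ in range(200):
    _, grads = loss_and_grads(params, x, y)
    params = sgd_update(params, grads)
final_loss, _ = loss_and_grads(params, x, y)
print(final_loss < initial_loss)  # True
```

Because sgd_update recurses on the dict structure rather than on fixed key names, it works for any nesting depth, much like a tree map over a Flax parameter tree.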