
swarmrl.agents.actor_critic Module API Reference

Module for the Actor-Critic RL protocol.

ActorCriticAgent

Bases: Agent

Class to handle the actor-critic RL protocol.

Source code in swarmrl/agents/actor_critic.py
class ActorCriticAgent(Agent):
    """
    Class to handle the actor-critic RL protocol.
    """

    def __init__(
        self,
        particle_type: int,
        network: Network,
        task: Task,
        observable: Observable,
        actions: dict,
        loss: Loss = ProximalPolicyLoss(),
        train: bool = True,
    ):
        """
        Constructor for the actor-critic protocol.

        Parameters
        ----------
        particle_type : int
                Particle ID this RL protocol applies to.
        network : Network
                Actor-critic network used by the agent.
        observable : Observable
                Observable for this particle type and the network input.
        task : Task
                Task for this particle type to perform.
        actions : dict
                Actions allowed for the particle.
        loss : Loss (default=ProximalPolicyLoss)
                Loss function to use to update the networks.
        train : bool (default=True)
                Flag to indicate if the agent is training.
        """
        # Properties of the agent.
        self.network = network
        self.particle_type = particle_type
        self.task = task
        self.observable = observable
        self.actions = actions
        self.train = train
        self.loss = loss

        # Trajectory to be updated.
        self.trajectory = TrajectoryInformation(particle_type=self.particle_type)

    def __name__(self) -> str:
        """
        Give the class a name.

        Returns
        -------
        name : str
            Name of the class.
        """
        return "ActorCriticAgent"

    def update_agent(self) -> tuple:
        """
        Update the agent's network.

        Returns
        -------
        rewards : float
                Net reward for the agent.
        killed : bool
                Whether or not this agent killed the
                simulation.
        """
        # Collect data for returns.
        rewards = self.trajectory.rewards
        killed = self.trajectory.killed

        # Compute loss for actor and critic.
        self.loss.compute_loss(
            network=self.network,
            episode_data=self.trajectory,
        )

        # Reset the trajectory storage.
        self.reset_trajectory()

        return rewards, killed

    def reset_agent(self, colloids: typing.List[Colloid]):
        """
        Reset several properties of the agent.

        Reset the observables and tasks for the agent.

        Parameters
        ----------
        colloids : typing.List[Colloid]
                Colloids to use in the initialization.
        """
        self.observable.initialize(colloids)
        self.task.initialize(colloids)

    def reset_trajectory(self):
        """
        Reset the stored trajectory data and the task kill switch.
        """
        self.task.kill_switch = False  # Reset here.
        self.trajectory = TrajectoryInformation(particle_type=self.particle_type)

    def initialize_network(self):
        """
        Initialize all of the models in the gym.
        """
        self.network.reinitialize_network()

    def save_agent(self, directory: str):
        """
        Save the agent network state.

        Parameters
        ----------
        directory : str
                Location to save the models.
        """
        self.network.export_model(
            filename=f"{self.__name__()}_{self.particle_type}", directory=directory
        )

    def restore_agent(self, directory: str):
        """
        Restore the agent state from a directory.
        """
        self.network.restore_model_state(
            filename=f"{self.__name__()}_{self.particle_type}", directory=directory
        )

    def calc_action(self, colloids: typing.List[Colloid]) -> typing.List[Action]:
        """
        Compute the new state for the agent.

        Returns the chosen actions to the force function, which
        talks to the ESPResSo engine.

        Parameters
        ----------
        colloids : List[Colloid]
                List of colloids in the system.
        """
        state_description = self.observable.compute_observable(colloids)
        action_indices, log_probs = self.network.compute_action(
            observables=state_description
        )
        chosen_actions = np.take(list(self.actions.values()), action_indices, axis=-1)

        # Update the trajectory information.
        if self.train:
            self.trajectory.features.append(state_description)
            self.trajectory.actions.append(action_indices)
            self.trajectory.log_probs.append(log_probs)
            self.trajectory.rewards.append(self.task(colloids))
            self.trajectory.killed = self.task.kill_switch

        self.kill_switch = self.task.kill_switch

        return chosen_actions
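
In a typical run the surrounding gym or force function drives this class step by step: calc_action is called once per simulation slice with the current colloids, and update_agent is called once the episode ends. The sketch below only illustrates that loop; the agent object and the get_colloids helper that returns the current colloid list from the engine are assumptions, not part of swarmrl's documented API.

def run_episode(agent, get_colloids, n_steps=100):
    """Drive one training episode using only the documented agent interface."""
    # Re-initialise the observable and task with the starting configuration.
    agent.reset_agent(get_colloids())

    for _ in range(n_steps):
        colloids = get_colloids()
        # Chooses actions and, if agent.train is True, records trajectory data.
        actions = agent.calc_action(colloids)
        # ... hand `actions` back to the force function / engine here ...
        if agent.kill_switch:
            break

    # Episode end: compute the loss, update the network, clear the trajectory.
    rewards, killed = agent.update_agent()
    return rewards, killed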

__init__(particle_type, network, task, observable, actions, loss=ProximalPolicyLoss(), train=True)

Constructor for the actor-critic protocol.

Parameters

particle_type : int
    Particle ID this RL protocol applies to.
network : Network
    Actor-critic network used by the agent.
observable : Observable
    Observable for this particle type and the network input.
task : Task
    Task for this particle type to perform.
actions : dict
    Actions allowed for the particle.
loss : Loss (default=ProximalPolicyLoss)
    Loss function to use to update the networks.
train : bool (default=True)
    Flag to indicate if the agent is training.

Source code in swarmrl/agents/actor_critic.py
def __init__(
    self,
    particle_type: int,
    network: Network,
    task: Task,
    observable: Observable,
    actions: dict,
    loss: Loss = ProximalPolicyLoss(),
    train: bool = True,
):
    """
    Constructor for the actor-critic protocol.

    Parameters
    ----------
    particle_type : int
            Particle ID this RL protocol applies to.
    network : Network
            Actor-critic network used by the agent.
    observable : Observable
            Observable for this particle type and the network input.
    task : Task
            Task for this particle type to perform.
    actions : dict
            Actions allowed for the particle.
    loss : Loss (default=ProximalPolicyLoss)
            Loss function to use to update the networks.
    train : bool (default=True)
            Flag to indicate if the agent is training.
    """
    # Properties of the agent.
    self.network = network
    self.particle_type = particle_type
    self.task = task
    self.observable = observable
    self.actions = actions
    self.train = train
    self.loss = loss

    # Trajectory to be updated.
    self.trajectory = TrajectoryInformation(particle_type=self.particle_type)
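
As a usage sketch, construction only wires the components together; nothing is computed until calc_action is called. The helper below is hypothetical, and the network, task, and observable objects are assumed to be built elsewhere from whichever concrete swarmrl implementations your experiment uses.

from swarmrl.agents.actor_critic import ActorCriticAgent


def build_agent(network, task, observable, actions, particle_type=0):
    """Hypothetical helper: bundle pre-built components into an agent."""
    return ActorCriticAgent(
        particle_type=particle_type,
        network=network,        # swarmrl Network (the actor-critic model)
        task=task,              # swarmrl Task providing the reward
        observable=observable,  # swarmrl Observable producing the network input
        actions=actions,        # e.g. a dict mapping action names to Action objects
        train=True,             # record trajectory data during calc_action
    )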

__name__()

Give the class a name.

Returns

name : str
    Name of the class.

Source code in swarmrl/agents/actor_critic.py
def __name__(self) -> str:
    """
    Give the class a name.

    Returns
    -------
    name : str
        Name of the class.
    """
    return "ActorCriticAgent"

calc_action(colloids)

Compute the new state for the agent.

Returns the chosen actions to the force function, which talks to the ESPResSo engine.

Parameters

colloids : List[Colloid]
    List of colloids in the system.

Source code in swarmrl/agents/actor_critic.py
def calc_action(self, colloids: typing.List[Colloid]) -> typing.List[Action]:
    """
    Compute the new state for the agent.

    Returns the chosen actions to the force function, which
    talks to the ESPResSo engine.

    Parameters
    ----------
    colloids : List[Colloid]
            List of colloids in the system.
    """
    state_description = self.observable.compute_observable(colloids)
    action_indices, log_probs = self.network.compute_action(
        observables=state_description
    )
    chosen_actions = np.take(list(self.actions.values()), action_indices, axis=-1)

    # Update the trajectory information.
    if self.train:
        self.trajectory.features.append(state_description)
        self.trajectory.actions.append(action_indices)
        self.trajectory.log_probs.append(log_probs)
        self.trajectory.rewards.append(self.task(colloids))
        self.trajectory.killed = self.task.kill_switch

    self.kill_switch = self.task.kill_switch

    return chosen_actions
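
Because calc_action only relies on the compute_observable, compute_action, __call__, and kill_switch members visible in the source above, its data flow can be exercised in isolation with duck-typed stand-ins. The stub classes below are purely illustrative and are not swarmrl components; a real setup would pass the library's own Observable, Network, Task, and Loss implementations.

from swarmrl.agents.actor_critic import ActorCriticAgent


class StubObservable:
    def initialize(self, colloids):
        pass

    def compute_observable(self, colloids):
        # One feature vector per colloid.
        return [[0.0, 1.0] for _ in colloids]


class StubNetwork:
    def compute_action(self, observables):
        # One action index and one log-probability per colloid.
        return [0 for _ in observables], [0.0 for _ in observables]


class StubTask:
    kill_switch = False

    def initialize(self, colloids):
        pass

    def __call__(self, colloids):
        # Per-colloid reward for this step.
        return [1.0 for _ in colloids]


class StubLoss:
    def compute_loss(self, network, episode_data):
        pass  # a real Loss would update the network here


agent = ActorCriticAgent(
    particle_type=0,
    network=StubNetwork(),
    task=StubTask(),
    observable=StubObservable(),
    actions={"do_nothing": "A", "translate": "B"},
    loss=StubLoss(),
)

chosen = agent.calc_action(colloids=[object(), object()])
print(chosen)                          # ['A' 'A'], selected via np.take on the action values
print(len(agent.trajectory.features))  # 1 recorded step, since train=True by default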

initialize_network()

Initialize all of the models in the gym.

Source code in swarmrl/agents/actor_critic.py
def initialize_network(self):
    """
    Initialize all of the models in the gym.
    """
    self.network.reinitialize_network()

reset_agent(colloids)

Reset several properties of the agent.

Reset the observables and tasks for the agent.

Parameters

colloids : typing.List[Colloid]
    Colloids to use in the initialization.

Source code in swarmrl/agents/actor_critic.py
def reset_agent(self, colloids: typing.List[Colloid]):
    """
    Reset several properties of the agent.

    Reset the observables and tasks for the agent.

    Parameters
    ----------
    colloids : typing.List[Colloid]
            Colloids to use in the initialization.
    """
    self.observable.initialize(colloids)
    self.task.initialize(colloids)
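
This is typically called at the start of an episode, before the first calc_action, so that the observable and task see the initial colloid configuration. A brief sketch, with agent and colloids assumed to exist:

# `colloids` is assumed to be the current colloid list from the simulation.
agent.reset_agent(colloids)            # re-initialise observable and task
first_actions = agent.calc_action(colloids)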

reset_trajectory()

Reset the stored trajectory data and the task kill switch.

Source code in swarmrl/agents/actor_critic.py
def reset_trajectory(self):
    """
    Reset the stored trajectory data and the task kill switch.
    """
    self.task.kill_switch = False  # Reset here.
    self.trajectory = TrajectoryInformation(particle_type=self.particle_type)

restore_agent(directory)

Restore the agent state from a directory.

Source code in swarmrl/agents/actor_critic.py
def restore_agent(self, directory: str):
    """
    Restore the agent state from a directory.
    """
    self.network.restore_model_state(
        filename=f"{self.__name__()}_{self.particle_type}", directory=directory
    )

save_agent(directory)

Save the agent network state.

Parameters

directory : str
    Location to save the models.

Source code in swarmrl/agents/actor_critic.py
def save_agent(self, directory: str):
    """
    Save the agent network state.

    Parameters
    ----------
    directory : str
            Location to save the models.
    """
    self.network.export_model(
        filename=f"{self.__name__()}_{self.particle_type}", directory=directory
    )
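
Together with restore_agent above, this gives a simple checkpoint round trip. As the source shows, the file name is derived from the class name and particle type, so the restoring agent must be configured with the same particle_type. A short sketch, assuming agent exists:

agent.save_agent(directory="./checkpoints")     # writes ActorCriticAgent_<particle_type>
# ... later, or in a fresh process, with an identically configured agent ...
agent.restore_agent(directory="./checkpoints")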

update_agent()

Update the agent's network.

Returns

rewards : float
    Net reward for the agent.
killed : bool
    Whether or not this agent killed the simulation.

Source code in swarmrl/agents/actor_critic.py
def update_agent(self) -> tuple:
    """
    Update the agent's network.

    Returns
    -------
    rewards : float
            Net reward for the agent.
    killed : bool
            Whether or not this agent killed the
            simulation.
    """
    # Collect data for returns.
    rewards = self.trajectory.rewards
    killed = self.trajectory.killed

    # Compute loss for actor and critic.
    self.loss.compute_loss(
        network=self.network,
        episode_data=self.trajectory,
    )

    # Reset the trajectory storage.
    self.reset_trajectory()

    return rewards, killed
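
A training episode therefore ends with a single call that collects the recorded rewards, triggers the loss computation, and clears the stored trajectory. Sketch, assuming agent has recorded at least one step:

rewards, killed = agent.update_agent()
print(killed)  # True if the task tripped its kill switch during the episode
# The trajectory is now empty; the next calc_action call starts a fresh episode.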

TrajectoryInformation dataclass

Helper dataclass for training RL models.

Source code in swarmrl/agents/actor_critic.py
@dataclass
class TrajectoryInformation:
    """
    Helper dataclass for training RL models.
    """

    particle_type: int
    features: list = field(default_factory=list)
    actions: list = field(default_factory=list)
    log_probs: list = field(default_factory=list)
    rewards: list = field(default_factory=list)
    killed: bool = False
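
The dataclass is plain Python and can be inspected or filled directly, which is mainly useful when writing a custom Loss that consumes episode_data. A minimal, runnable example:

from swarmrl.agents.actor_critic import TrajectoryInformation

traj = TrajectoryInformation(particle_type=0)
traj.features.append([0.1, 0.2])   # observable recorded for one step
traj.actions.append([1])           # chosen action indices
traj.log_probs.append([-0.3])      # log-probabilities from the network
traj.rewards.append([0.5])         # reward returned by the task
print(traj.killed)                 # False until the task's kill switch fires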