
swarmrl.networks.flax_network Module API Reference

Jax model for reinforcement learning.

FlaxModel

Bases: Network, ABC

Class for the Flax model in ZnRND.

Attributes

epoch_count : int
    Current epoch stage. Used in saving the models.

Source code in swarmrl/networks/flax_network.py
class FlaxModel(Network, ABC):
    """
    Class for the Flax model in ZnRND.

    Attributes
    ----------
    epoch_count : int
            Current epoch stage. Used in saving the models.
    """

    def __init__(
        self,
        flax_model: nn.Module,
        input_shape: tuple,
        optimizer: GradientTransformation = None,
        exploration_policy: ExplorationPolicy = RandomExploration(probability=0.0),
        sampling_strategy: SamplingStrategy = GumbelDistribution(),
        rng_key: int = None,
        deployment_mode: bool = False,
    ):
        """
        Constructor for a Flax model.

        Parameters
        ----------
        flax_model : nn.Module
                Flax model as a neural network.
        optimizer : Callable
                Optimizer to use in the training. Optax is used by default and
                cross-compatibility is not assured.
        input_shape : tuple
                Shape of the NN input.
        rng_key : int
                Key to seed the model with. Default is a randomly generated key but
                the parameter is here for testing purposes.
        deployment_mode : bool
                If true, the model is a shell for the network and nothing else. No
                training can be performed, this is only used in deployment.
        """
        if rng_key is None:
            rng_key = onp.random.randint(0, 1027465782564)
        self.sampling_strategy = sampling_strategy
        self.model = flax_model
        self.apply_fn = jax.jit(
            jax.vmap(self.model.apply, in_axes=(None, 0))
        )  # Map over agents
        self.batch_apply_fn = jax.jit(jax.vmap(self.apply_fn, in_axes=(None, 0)))
        self.input_shape = input_shape
        self.model_state = None

        if not deployment_mode:
            self.optimizer = optimizer
            self.exploration_policy = exploration_policy

            # initialize the model state
            init_rng = jax.random.PRNGKey(rng_key)
            _, subkey = jax.random.split(init_rng)
            self.model_state = self._create_train_state(subkey)

            self.epoch_count = 0

    def _create_custom_train_state(self, optimizer: dict):
        """
        Deal with the optimizers in case of complex configuration.
        """
        return type("TrainState", (TrainState,), optimizer)

    def _create_train_state(self, init_rng: int) -> TrainState:
        """
        Create a training state of the model.

        Parameters
        ----------
        init_rng : int
                Initial rng for train state that is immediately deleted.

        Returns
        -------
        state : TrainState / CustomTrainState
                initial state of model to then be trained.
                If you have multiple optimizers, this will create a custom train state.
        """
        params = self.model.init(init_rng, np.ones(list(self.input_shape)))["params"]

        if isinstance(self.optimizer, dict):
            CustomTrainState = self._create_custom_train_state(self.optimizer)

            return CustomTrainState.create(
                apply_fn=self.model.apply, params=params, tx=self.optimizer
            )
        else:
            return TrainState.create(
                apply_fn=self.model.apply, params=params, tx=self.optimizer
            )

    def reinitialize_network(self):
        """
        Reinitialize the neural network.
        """
        rng_key = onp.random.randint(0, 1027465782564)
        init_rng = jax.random.PRNGKey(rng_key)
        _, subkey = jax.random.split(init_rng)
        self.model_state = self._create_train_state(subkey)

    def update_model(self, grads):
        """
        Train the model.

        See the parent class for a full doc-string.
        """
        # Logging for grads and pre-train model state
        logger.debug(f"{grads=}")
        logger.debug(f"{self.model_state=}")

        if isinstance(self.optimizer, dict):
            pass

        else:
            self.model_state = self.model_state.apply_gradients(grads=grads)

        # Logging for post-train model state
        logger.debug(f"{self.model_state=}")

        self.epoch_count += 1

    def compute_action(self, observables: List):
        """
        Compute an action from the action space.

        This method computes an action on all colloids of the relevant type.

        Parameters
        ----------
        observables : List (n_agents, observable_dimension)
                Observable for each colloid for which the action should be computed.

        Returns
        -------
        tuple : (np.ndarray, np.ndarray)
                The first element is an array of indices corresponding to the action
                taken by the agent. The value is bounded between 0 and the number of
                output neurons. The second element is an array of the corresponding
                log_probs (i.e. the output of the network put through a softmax).
        """
        # Compute state
        try:
            logits, _ = self.apply_fn(
                {"params": self.model_state.params}, np.array(observables)
            )
        except AttributeError:  # We need this for loaded models.
            logits, _ = self.apply_fn(
                {"params": self.model_state["params"]}, np.array(observables)
            )
        logger.debug(f"{logits=}")  # (n_colloids, n_actions)

        # Compute the action
        indices = self.sampling_strategy(logits)
        # Add a small value to the log_probs to avoid log(0) errors.
        eps = 1e-8
        log_probs = np.log(jax.nn.softmax(logits) + eps)

        indices = self.exploration_policy(
            indices, logits.shape[-1], onp.random.randint(8759865)
        )
        return (
            indices,
            np.take_along_axis(log_probs, indices.reshape(-1, 1), axis=1).reshape(-1),
        )

    def export_model(self, filename: str = "model", directory: str = "Models"):
        """
        Export the model state to a directory.

        Parameters
        ----------
        filename : str (default=model)
                Name of the file the models are saved in.
        directory : str (default=Models)
                Directory in which to save the models. If the directory is not
                in the current directory, it will be created.

        """
        model_params = self.model_state.params
        opt_state = self.model_state.opt_state
        opt_step = self.model_state.step
        epoch = self.epoch_count

        os.makedirs(directory, exist_ok=True)

        with open(directory + "/" + filename + ".pkl", "wb") as f:
            pickle.dump((model_params, opt_state, opt_step, epoch), f)

    def restore_model_state(self, filename, directory):
        """
        Restore the model state from a file.

        Parameters
        ----------
        filename : str
                Name of the model state file
        directory : str
                Path to the model state file.

        Returns
        -------
        Updates the model state.
        """

        with open(directory + "/" + filename + ".pkl", "rb") as f:
            model_params, opt_state, opt_step, epoch = pickle.load(f)

        self.model_state = self.model_state.replace(
            params=model_params, opt_state=opt_state, step=opt_step
        )
        self.epoch_count = epoch

    def __call__(self, params: FrozenDict, episode_features):
        """
        vmapped version of the model call function.
        Operates on a batch of episodes.

        Parameters
        ----------
        params : dict
                Parameters of the model.
        episode_features: np.ndarray (n_steps, n_agents, observable_dimension)
                Features of the episode. This contains the features of all agents,
                for all time steps in the episode.


        Returns
        -------
        logits : np.ndarray
                Output of the network.
        """

        return self.batch_apply_fn({"params": params}, episode_features)

__call__(params, episode_features)

vmapped version of the model call function. Operates on a batch of episodes.

Parameters

params : dict
    Parameters of the model.
episode_features : np.ndarray (n_steps, n_agents, observable_dimension)
    Features of the episode. This contains the features of all agents,
    for all time steps in the episode.

Returns

logits : np.ndarray
    Output of the network.
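Example

A minimal usage sketch, assuming a model constructed as in the __init__ example further down this page (a network returning (logits, value)); the step, agent, and observable counts are illustrative.

import jax.numpy as jnp

# Episode features: (n_steps, n_agents, observable_dimension); shapes are illustrative.
episode_features = jnp.ones((100, 5, 3))

logits, value = model(model.model_state.params, episode_features)
# With the assumed two-output network, logits has shape (100, 5, 4)
# and value has shape (100, 5, 1).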

Source code in swarmrl/networks/flax_network.py
def __call__(self, params: FrozenDict, episode_features):
    """
    vmapped version of the model call function.
    Operates on a batch of episodes.

    Parameters
    ----------
    params : dict
            Parameters of the model.
    episode_features: np.ndarray (n_steps, n_agents, observable_dimension)
            Features of the episode. This contains the features of all agents,
            for all time steps in the episode.


    Returns
    -------
    logits : np.ndarray
            Output of the network.
    """

    return self.batch_apply_fn({"params": params}, episode_features)

__init__(flax_model, input_shape, optimizer=None, exploration_policy=RandomExploration(probability=0.0), sampling_strategy=GumbelDistribution(), rng_key=None, deployment_mode=False)

Constructor for a Flax model.

Parameters

flax_model : nn.Module
    Flax model as a neural network.
optimizer : Callable
    Optimizer to use in the training. Optax is used by default and
    cross-compatibility is not assured.
input_shape : tuple
    Shape of the NN input.
rng_key : int
    Key to seed the model with. Default is a randomly generated key, but
    the parameter is here for testing purposes.
deployment_mode : bool
    If true, the model is a shell for the network and nothing else. No
    training can be performed; this is only used in deployment.
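Example

A minimal construction sketch. The ActorNet architecture, the (3,) input shape, and the optax.adam learning rate are illustrative assumptions; the two-output (logits, value) signature is assumed because compute_action unpacks two values from the network output.

import flax.linen as nn
import optax

from swarmrl.networks.flax_network import FlaxModel

class ActorNet(nn.Module):
    """Tiny illustrative network returning (logits, value)."""

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        logits = nn.Dense(4)(x)  # one logit per action
        value = nn.Dense(1)(x)   # auxiliary value head
        return logits, value

model = FlaxModel(
    flax_model=ActorNet(),
    input_shape=(3,),            # assumed observable dimension per agent
    optimizer=optax.adam(1e-3),
    rng_key=42,
)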

Source code in swarmrl/networks/flax_network.py
def __init__(
    self,
    flax_model: nn.Module,
    input_shape: tuple,
    optimizer: GradientTransformation = None,
    exploration_policy: ExplorationPolicy = RandomExploration(probability=0.0),
    sampling_strategy: SamplingStrategy = GumbelDistribution(),
    rng_key: int = None,
    deployment_mode: bool = False,
):
    """
    Constructor for a Flax model.

    Parameters
    ----------
    flax_model : nn.Module
            Flax model as a neural network.
    optimizer : Callable
            Optimizer to use in the training. Optax is used by default and
            cross-compatibility is not assured.
    input_shape : tuple
            Shape of the NN input.
    rng_key : int
            Key to seed the model with. Default is a randomly generated key but
            the parameter is here for testing purposes.
    deployment_mode : bool
            If true, the model is a shell for the network and nothing else. No
            training can be performed, this is only used in deployment.
    """
    if rng_key is None:
        rng_key = onp.random.randint(0, 1027465782564)
    self.sampling_strategy = sampling_strategy
    self.model = flax_model
    self.apply_fn = jax.jit(
        jax.vmap(self.model.apply, in_axes=(None, 0))
    )  # Map over agents
    self.batch_apply_fn = jax.jit(jax.vmap(self.apply_fn, in_axes=(None, 0)))
    self.input_shape = input_shape
    self.model_state = None

    if not deployment_mode:
        self.optimizer = optimizer
        self.exploration_policy = exploration_policy

        # initialize the model state
        init_rng = jax.random.PRNGKey(rng_key)
        _, subkey = jax.random.split(init_rng)
        self.model_state = self._create_train_state(subkey)

        self.epoch_count = 0

compute_action(observables)

Compute an action from the action space.

This method computes an action on all colloids of the relevant type.

Parameters

observables : List (n_agents, observable_dimension)
    Observable for each colloid for which the action should be computed.

Returns

tuple : (np.ndarray, np.ndarray)
    The first element is an array of indices corresponding to the action
    taken by the agent. The value is bounded between 0 and the number of
    output neurons. The second element is an array of the corresponding
    log_probs (i.e. the output of the network put through a softmax).
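Example

A short sketch of computing actions for a group of colloids, continuing the model assumed in the __init__ example above; the number of colloids and the observable dimension are illustrative.

import numpy as onp

observables = onp.random.rand(5, 3)  # 5 colloids, 3-dimensional observables (assumed)
indices, log_probs = model.compute_action(observables)
# indices: shape (5,), chosen action index per colloid
# log_probs: shape (5,), log-probability of each chosen action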

Source code in swarmrl/networks/flax_network.py
def compute_action(self, observables: List):
    """
    Compute an action from the action space.

    This method computes an action on all colloids of the relevant type.

    Parameters
    ----------
    observables : List (n_agents, observable_dimension)
            Observable for each colloid for which the action should be computed.

    Returns
    -------
    tuple : (np.ndarray, np.ndarray)
            The first element is an array of indices corresponding to the action
            taken by the agent. The value is bounded between 0 and the number of
            output neurons. The second element is an array of the corresponding
            log_probs (i.e. the output of the network put through a softmax).
    """
    # Compute state
    try:
        logits, _ = self.apply_fn(
            {"params": self.model_state.params}, np.array(observables)
        )
    except AttributeError:  # We need this for loaded models.
        logits, _ = self.apply_fn(
            {"params": self.model_state["params"]}, np.array(observables)
        )
    logger.debug(f"{logits=}")  # (n_colloids, n_actions)

    # Compute the action
    indices = self.sampling_strategy(logits)
    # Add a small value to the log_probs to avoid log(0) errors.
    eps = 1e-8
    log_probs = np.log(jax.nn.softmax(logits) + eps)

    indices = self.exploration_policy(
        indices, logits.shape[-1], onp.random.randint(8759865)
    )
    return (
        indices,
        np.take_along_axis(log_probs, indices.reshape(-1, 1), axis=1).reshape(-1),
    )

export_model(filename='model', directory='Models')

Export the model state to a directory.

Parameters

filename : str (default=model)
    Name of the file the models are saved in.
directory : str (default=Models)
    Directory in which to save the models. If the directory is not
    in the current directory, it will be created.
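Example

A short sketch; the file and directory names are illustrative.

model.export_model(filename="actor_checkpoint", directory="Models")
# Writes Models/actor_checkpoint.pkl containing
# (params, optimizer state, optimizer step, epoch_count).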

Source code in swarmrl/networks/flax_network.py
def export_model(self, filename: str = "model", directory: str = "Models"):
    """
    Export the model state to a directory.

    Parameters
    ----------
    filename : str (default=model)
            Name of the file the models are saved in.
    directory : str (default=Models)
            Directory in which to save the models. If the directory is not
            in the current directory, it will be created.

    """
    model_params = self.model_state.params
    opt_state = self.model_state.opt_state
    opt_step = self.model_state.step
    epoch = self.epoch_count

    os.makedirs(directory, exist_ok=True)

    with open(directory + "/" + filename + ".pkl", "wb") as f:
        pickle.dump((model_params, opt_state, opt_step, epoch), f)

reinitialize_network()

Reinitialize the neural network.

Source code in swarmrl/networks/flax_network.py
def reinitialize_network(self):
    """
    Reinitialize the neural network.
    """
    rng_key = onp.random.randint(0, 1027465782564)
    init_rng = jax.random.PRNGKey(rng_key)
    _, subkey = jax.random.split(init_rng)
    self.model_state = self._create_train_state(subkey)

restore_model_state(filename, directory)

Restore the model state from a file.

Parameters

filename : str
    Name of the model state file.
directory : str
    Path to the model state file.

Returns

Updates the model state.
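Example

A sketch of restoring the state exported above. The model must already be constructed with the same architecture and optimizer, since only params, opt_state, step, and epoch_count are replaced on the existing state.

model.restore_model_state(filename="actor_checkpoint", directory="Models")
# model.model_state now holds the restored parameters and optimizer state,
# and model.epoch_count is set to the saved epoch.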

Source code in swarmrl/networks/flax_network.py
def restore_model_state(self, filename, directory):
    """
    Restore the model state from a file.

    Parameters
    ----------
    filename : str
            Name of the model state file
    directory : str
            Path to the model state file.

    Returns
    -------
    Updates the model state.
    """

    with open(directory + "/" + filename + ".pkl", "rb") as f:
        model_params, opt_state, opt_step, epoch = pickle.load(f)

    self.model_state = self.model_state.replace(
        params=model_params, opt_state=opt_state, step=opt_step
    )
    self.epoch_count = epoch

update_model(grads)

Train the model.

See the parent class for a full doc-string.
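Example

A sketch of one gradient step, reusing the model and episode_features from the examples above; the loss function is purely illustrative and not a real RL objective.

import jax

def loss_fn(params):
    logits, _ = model(params, episode_features)
    return logits.mean()  # placeholder loss

grads = jax.grad(loss_fn)(model.model_state.params)
model.update_model(grads)  # applies the Optax update and increments epoch_count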

Source code in swarmrl/networks/flax_network.py
def update_model(self, grads):
    """
    Train the model.

    See the parent class for a full doc-string.
    """
    # Logging for grads and pre-train model state
    logger.debug(f"{grads=}")
    logger.debug(f"{self.model_state=}")

    if isinstance(self.optimizer, dict):
        pass

    else:
        self.model_state = self.model_state.apply_gradients(grads=grads)

    # Logging for post-train model state
    logger.debug(f"{self.model_state=}")

    self.epoch_count += 1