Skip to content

swarmrl.utils.colloid_utils Module API Reference

Various functions for operating on colloids.

compute_distance_matrix(set_a, set_b)

Compute a distance matrix between two sets.

Helper function for computing the distance sets of colloids. This is not a commutative operation: if you swap a for b, you will receive a different matrix shape.

Parameters

set_a : jnp.ndarray First set of points. set_b : jnp.ndarray Second set of points.

Source code in swarmrl/utils/colloid_utils.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
@jax.jit
def compute_distance_matrix(set_a, set_b):
    """
    Compute the pairwise displacement vectors between two sets of points.

    Each entry along the leading axis of ``set_a`` is subtracted from the
    whole of ``set_b``, so the leading axis of the result runs over
    ``set_a``. The operation is not commutative: swapping the arguments
    generally gives a differently shaped result.

    Parameters
    ----------
    set_a : jnp.ndarray
        First set of points; mapped over its leading axis.
    set_b : jnp.ndarray
        Second set of points; used as a whole for every entry of ``set_a``.

    Returns
    -------
    jnp.ndarray
        Array whose ``i``-th entry is ``set_b - set_a[i]``.
    """
    # Closing over set_b keeps it un-mapped while vmap walks over set_a.
    displacements_from = jax.vmap(lambda point: set_b - point)
    return displacements_from(set_a)

compute_forces(r)

Compute the force (gradient of the interaction energy) between two colloids.

This uses a WCA potential to compute a relative force between two colloids. It is not physical. The method itself implements an energy computation which then uses Jax to compute the gradient of the energy with respect to the distance between the colloids.

Parameters

r : jnp.ndarray (dimension, ) Distance between the two colloids.

Source code in swarmrl/utils/colloid_utils.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@jax.jit
def compute_forces(r):
    """
    Compute the gradient of a repulsive pair energy between two colloids.

    The energy is ``1 / |r|**12``; the returned value is its gradient with
    respect to ``r``, obtained through Jax automatic differentiation. The
    result is a relative, non-physical force measure. Note that the raw
    gradient is returned, without a sign flip.

    Parameters
    ----------
    r : jnp.ndarray (dimension, )
        Distance vector between the two colloids.

    Returns
    -------
    jnp.ndarray (dimension, )
        Gradient of the energy with respect to ``r``.
    """

    def _energy(separation):
        distance = jnp.linalg.norm(separation)
        return 1 / distance**12

    return jax.grad(_energy)(r)

compute_torque(force, direction)

Compute the torque on a rod.

Parameters

Source code in swarmrl/utils/colloid_utils.py
64
65
66
67
68
69
70
71
72
73
@jax.jit
def compute_torque(force, direction):
    """
    Compute the torque on a rod.

    The torque is evaluated as the cross product ``direction x force``.

    Parameters
    ----------
    force : jnp.ndarray
        Force vector; second operand of the cross product.
    direction : jnp.ndarray
        Direction vector; first operand of the cross product.

    Returns
    -------
    jnp.ndarray
        Result of ``jnp.cross(direction, force)``.
    """
    torque = jnp.cross(direction, force)
    return torque

compute_torque_partition_on_rod(colloid_positions, rod_positions, rod_directions)

Compute the torque on a rod using a WCA potential.

Source code in swarmrl/utils/colloid_utils.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
@jax.jit
def compute_torque_partition_on_rod(colloid_positions, rod_positions, rod_directions):
    """
    Compute each colloid's share of the torque exerted on a rod.

    For every colloid / rod-particle pair a force is obtained from
    ``compute_forces``, turned into a torque with ``compute_torque``, and
    summed over the rod particles. The magnitudes of the per-colloid net
    torques are then normalised so that they sum to one.

    Parameters
    ----------
    colloid_positions : jnp.ndarray
        Positions of the colloids, presumably (n_colloids, 3).
    rod_positions : jnp.ndarray
        Positions of the rod particles, presumably (rod_particles, 3).
    rod_directions : jnp.ndarray
        One direction per rod particle (leading axis matches the rod
        particles); shared across all colloids.

    Returns
    -------
    jnp.ndarray
        Normalised torque magnitude per colloid.
    """
    # Pairwise displacements (n_colloids, rod_particles, ...), keeping only
    # the first two components of each vector.
    separations = compute_distance_matrix(colloid_positions, rod_positions)[:, :, :2]

    # Inner vmap runs over rod particles, outer vmap over colloids.
    # Result: (n_colloids, rod_particles, 2).
    pairwise_forces = jax.vmap(jax.vmap(compute_forces))(separations)

    # Pair each rod particle's force with its own direction, then repeat
    # for every colloid while sharing the same set of directions.
    per_rod_particle = jax.vmap(compute_torque, in_axes=(0, 0))
    per_colloid = jax.vmap(per_rod_particle, in_axes=(0, None))

    # Net torque each colloid exerts, summed over the rod particles.
    torque_per_colloid = per_colloid(pairwise_forces, rod_directions).sum(axis=1)

    # The small offset keeps the normalisation finite when all torques vanish.
    magnitudes = jnp.linalg.norm(torque_per_colloid, axis=-1) + 1e-8
    return magnitudes / magnitudes.sum()

get_colloid_indices(colloids, p_type)

Get the indices of the colloids in the observable of a specific type.

Parameters

colloids : List[Colloid] List of colloids from which to get the indices. p_type : int Type of the colloids to get the indices for.

Returns

indices : List[int] List of indices for the colloids of a particular type.

Source code in swarmrl/utils/colloid_utils.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def get_colloid_indices(colloids: List["Colloid"], p_type: int) -> List[int]:
    """
    Get the indices of the colloids of a specific type.

    Parameters
    ----------
    colloids : List[Colloid]
            List of colloids from which to get the indices.
    p_type : int
            Type to select; compared against each colloid's ``type``
            attribute.

    Returns
    -------
    indices : List[int]
            Positions in ``colloids`` of the entries whose type equals
            ``p_type``, in their original order.
    """
    return [
        position
        for position, candidate in enumerate(colloids)
        if candidate.type == p_type
    ]