Skip to content

swarmrl.components.swarm Module API Reference

Class for the Swarm Pytree Agent

Swarm dataclass

Wrapper class for a collection (swarm) of colloid objects.

Source code in swarmrl/components/swarm.py (lines 18–100)
@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Swarm:
    """
    Collection of colloids stored as stacked arrays, registered as a JAX PyTree.

    When built by ``create_swarm`` every attribute holds the values of all
    colloids at once, one row per colloid.

    Attributes
    ----------
    pos : np.ndarray
        Colloid positions, one row per colloid.
    director : np.ndarray
        Colloid directors, one row per colloid.
    id : int
        Colloid ids (an array of shape ``(n_colloids, 1)`` when built by
        ``create_swarm``).
    velocity : np.ndarray, optional
        Colloid velocities, one row per colloid, or None.
    type : int
        Colloid types (an array of shape ``(n_colloids, 1)`` when built by
        ``create_swarm``).
    type_indices : dict, optional
        Mapping from a colloid type to the row indices of the colloids of
        that type. None for a single-species sub-swarm.
    """

    # Colloid attributes
    pos: np.ndarray
    director: np.ndarray
    id: int
    velocity: np.ndarray = None
    type: int = 0

    # Swarm attributes
    type_indices: dict = None

    def __repr__(self) -> str:
        """
        Return a string representation of the swarm.

        The label is the object's class name, so the representation no
        longer misleadingly identifies a swarm as a single ``Colloid``.
        """
        return (
            f"{self.__class__.__name__}(pos={self.pos}, director={self.director}, "
            f"id={self.id}, velocity={self.velocity}, type={self.type})"
        )

    def __eq__(self, other):
        """
        Compare two swarms by their ``id`` attribute.

        Returns ``NotImplemented`` for non-Swarm operands, so comparisons
        such as ``swarm == None`` fall back to Python's default handling
        instead of raising ``AttributeError``. For two swarms the result is
        ``self.id == other.id``; with array ids this is an element-wise
        comparison, not a single bool.
        """
        if not isinstance(other, Swarm):
            return NotImplemented
        return self.id == other.id

    def tree_flatten(self) -> tuple:
        """
        Flatten the PyTree.

        Returns
        -------
        tuple
            ``(children, aux_data)``. ``children`` lists the fields in
            constructor order so ``tree_unflatten`` can rebuild the object
            positionally. ``aux_data`` is always None.
        """
        children = (
            self.pos,
            self.director,
            self.id,
            self.velocity,
            self.type,
            self.type_indices,
        )
        aux_data = None
        return (children, aux_data)

    def get_species_swarm(self, species: int) -> Swarm:
        """
        Get a swarm of one species.

        Parameters
        ----------
        species : int
            Species index, used as a key into ``type_indices``.

        Returns
        -------
        partitioned_swarm : Swarm
            Swarm containing only the colloids of that species. Its
            ``type_indices`` is None, and its ``velocity`` is None if this
            swarm's ``velocity`` is None.

        Raises
        ------
        KeyError
            If ``species`` is not a key of ``type_indices``.
        TypeError
            If ``type_indices`` is None (e.g. on a sub-swarm).
        """
        indices = self.type_indices[species]
        # velocity is optional (defaults to None) and np.take cannot handle None.
        velocity = (
            None if self.velocity is None else np.take(self.velocity, indices, axis=0)
        )
        return Swarm(
            pos=np.take(self.pos, indices, axis=0),
            director=np.take(self.director, indices, axis=0),
            id=np.take(self.id, indices, axis=0),
            velocity=velocity,
            type=np.take(self.type, indices, axis=0),
            # A single-species sub-swarm keeps no per-type index map.
            type_indices=None,
        )

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> Swarm:
        """
        Unflatten the PyTree.

        This method is required by Pytrees in Jax.

        Parameters
        ----------
        aux_data : None
            Auxiliary data. Not used in this class.
        children : tuple
            Field values in constructor order, as produced by
            ``tree_flatten``.
        """
        return cls(*children)

__repr__()

Return a string representation of the swarm.

Source code in swarmrl/components/swarm.py (lines 35–42)
def __repr__(self) -> str:
    """
    Return a string representation of the swarm.

    The label is the object's class name, followed by its ``pos``,
    ``director``, ``id``, ``velocity`` and ``type`` attributes, so a swarm
    is no longer misleadingly shown as a single ``Colloid``.
    """
    return (
        f"{self.__class__.__name__}(pos={self.pos}, director={self.director}, "
        f"id={self.id}, velocity={self.velocity}, type={self.type})"
    )

get_species_swarm(species)

Get a swarm of one species.

Parameters

species : int — Species index.

Returns

partitioned_swarm : Swarm — Swarm containing only the colloids of that species.

Source code in swarmrl/components/swarm.py (lines 62–84)
def get_species_swarm(self, species: int) -> Swarm:
    """
    Get a swarm of one species.

    Parameters
    ----------
    species : int
        Species index, used as a key into ``type_indices``.

    Returns
    -------
    partitioned_swarm : Swarm
        Swarm containing only the colloids of that species. Its
        ``type_indices`` is None, and its ``velocity`` is None if this
        swarm's ``velocity`` is None.

    Raises
    ------
    KeyError
        If ``species`` is not a key of ``type_indices``.
    TypeError
        If ``type_indices`` is None (e.g. on a sub-swarm).
    """
    indices = self.type_indices[species]
    # velocity is optional (defaults to None) and np.take cannot handle None.
    velocity = (
        None if self.velocity is None else np.take(self.velocity, indices, axis=0)
    )
    return Swarm(
        pos=np.take(self.pos, indices, axis=0),
        director=np.take(self.director, indices, axis=0),
        id=np.take(self.id, indices, axis=0),
        velocity=velocity,
        type=np.take(self.type, indices, axis=0),
        # A single-species sub-swarm keeps no per-type index map.
        type_indices=None,
    )

tree_flatten()

Flatten the PyTree.

Source code in swarmrl/components/swarm.py (lines 47–60)
def tree_flatten(self) -> tuple:
    """
    Split the object into PyTree children and auxiliary data.

    Returns
    -------
    tuple
        ``(children, aux_data)``. ``children`` holds the field values in
        constructor order: pos, director, id, velocity, type, type_indices.
        ``aux_data`` is always None.
    """
    field_order = ("pos", "director", "id", "velocity", "type", "type_indices")
    leaves = tuple(getattr(self, field_name) for field_name in field_order)
    return leaves, None

tree_unflatten(aux_data, children) classmethod

Unflatten the PyTree.

This method is required by Pytrees in Jax.

Parameters

aux_data : None — Auxiliary data. Not used in this class.
children : tuple — Tuple of children to be unflattened.

Source code in swarmrl/components/swarm.py (lines 86–100)
@classmethod
def tree_unflatten(cls, aux_data, children) -> Swarm:
    """
    Rebuild an instance from its PyTree children.

    JAX calls this hook to reverse ``tree_flatten``.

    Parameters
    ----------
    aux_data : None
        Auxiliary data; ignored by this class.
    children : tuple
        Field values, passed positionally to the constructor in order.
    """
    field_values = list(children)
    return cls(*field_values)

create_swarm(colloids)

Create a swarm from a list of colloid objects.

Parameters

colloids : List[Colloid] — List of colloid objects.

Returns

Swarm — Swarm object containing all of the given colloids.

Source code in swarmrl/components/swarm.py (lines 103–134)
def create_swarm(colloids: List[Colloid]) -> Swarm:
    """
    Create a swarm from a list of colloid objects.

    The per-colloid attributes are stacked into arrays with one row per
    colloid. A mapping is then built from each distinct colloid type to the
    indices returned by ``get_colloid_indices`` for that type.

    Parameters
    ----------
    colloids : List[Colloid]
        Non-empty list of colloid objects. The first colloid's ``pos``,
        ``director`` and ``velocity`` fix the row width of the corresponding
        arrays.

    Returns
    -------
    Swarm
        Swarm object containing all of the given colloids.

    Raises
    ------
    ValueError
        If ``colloids`` is empty.
    """
    # Fail with a clear message instead of an IndexError from colloids[0].
    if len(colloids) == 0:
        raise ValueError("create_swarm requires at least one colloid.")
    reference = colloids[0]

    # Standard colloid attributes, stacked with one row per colloid.
    pos = np.array([c.pos for c in colloids]).reshape(-1, reference.pos.shape[0])
    director = np.array([c.director for c in colloids]).reshape(
        -1, reference.director.shape[0]
    )
    ids = np.array([c.id for c in colloids]).reshape(-1, 1)
    velocity = np.array([c.velocity for c in colloids]).reshape(
        -1, reference.velocity.shape[0]
    )
    types = np.array([c.type for c in colloids]).reshape(-1, 1)

    # Map every distinct colloid type to the indices of its colloids.
    type_indices = {
        t: np.array(get_colloid_indices(colloids, t)) for t in onp.unique(types)
    }

    return Swarm(pos, director, ids, velocity, types, type_indices)