Skip to content

swarmrl.components.colloid Module API Reference

Data class for the colloid agent.

Colloid dataclass

Wrapper class for a colloid object.

Source code in `swarmrl/components/colloid.py` (lines 11–49):
@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Colloid:
    """
    Wrapper class for a colloid object.

    Instances are immutable (``frozen=True``) and are registered as PyTree
    nodes via ``tree_flatten`` / ``tree_unflatten``.

    Identity semantics: two ``Colloid`` objects compare equal, and hash
    equally, when their ``id`` values are equal. Their ``pos``, ``director``,
    ``velocity`` and ``type`` are not taken into account.

    Attributes
    ----------
    pos : np.ndarray
        Position of the colloid.
    director : np.ndarray
        Director of the colloid.
    id : int
        Identifier of the colloid; the only field used by ``==`` and ``hash``.
    velocity : np.ndarray, optional
        Velocity of the colloid; ``None`` when not provided.
    type : int
        Type of the colloid; defaults to ``0``.
    """

    pos: np.ndarray
    director: np.ndarray
    id: int
    velocity: np.ndarray = None  # may be None when no velocity is supplied
    type: int = 0

    def __repr__(self):
        """
        Return a string representation of the colloid.
        """
        return (
            f"Colloid(pos={self.pos}, director={self.director}, id={self.id},"
            f" velocity={self.velocity}, type={self.type})"
        )

    def __eq__(self, other):
        """
        Compare two colloids by their ``id`` only.

        Returns ``NotImplemented`` for non-``Colloid`` operands. Python then
        falls back to its default comparison instead of raising
        ``AttributeError`` on e.g. ``colloid == None``.
        """
        if not isinstance(other, Colloid):
            return NotImplemented
        return self.id == other.id

    def __hash__(self):
        """
        Hash the colloid consistently with ``__eq__`` (by ``id`` only).

        Without an explicit ``__hash__``, ``dataclass(frozen=True)`` generates
        one over every field. That generated hash raises ``TypeError`` on the
        ``np.ndarray`` fields, and it would also disagree with ``__eq__``.
        """
        return hash(self.id)

    def tree_flatten(self):
        """
        Flatten the PyTree.

        Returns
        -------
        tuple
            ``(children, aux_data)``, where ``children`` is
            ``(pos, director, id, velocity, type)`` and ``aux_data`` is
            ``None`` (no static metadata is stored).
        """
        children = (self.pos, self.director, self.id, self.velocity, self.type)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Unflatten the PyTree.

        Rebuilds a ``Colloid`` from ``children`` given in the order produced by
        ``tree_flatten``. ``aux_data`` is ignored.
        """
        return cls(*children)

__repr__()

Return a string representation of the colloid.

Source code in `swarmrl/components/colloid.py` (lines 24–31):
def __repr__(self):
    """
    Describe the colloid with all of its fields.

    The result has the form
    ``Colloid(pos=..., director=..., id=..., velocity=..., type=...)``.
    """
    description = (
        f"Colloid(pos={self.pos}, director={self.director}, id={self.id},"
        f" velocity={self.velocity}, type={self.type})"
    )
    return description

tree_flatten()

Flatten the PyTree.

Source code in `swarmrl/components/colloid.py` (lines 36–42):
def tree_flatten(self):
    """
    Split the colloid into PyTree leaves and static metadata.

    The leaves are ``(pos, director, id, velocity, type)``, in that order.
    No static metadata is kept, so the second element is ``None``.
    """
    leaves = (self.pos, self.director, self.id, self.velocity, self.type)
    return leaves, None

tree_unflatten(aux_data, children) (classmethod)

Unflatten the PyTree.

Source code in `swarmrl/components/colloid.py` (lines 44–49):
@classmethod
def tree_unflatten(cls, aux_data, children):
    """
    Rebuild an instance from the leaves returned by ``tree_flatten``.

    ``children`` is passed positionally to the constructor. ``aux_data``
    carries no information and is discarded.
    """
    del aux_data  # no static metadata is stored
    return cls(*children)