Skip to content

API

Stable vs Experimental

This API reference lists both stable and experimental modules. Experimental APIs live under torchrir.experimental and may change without notice. Prefer top-level torchrir and documented submodules for stable use.

Modules

torchrir

torchrir

TorchRIR public API.

__all__ module-attribute

__all__ = ['DynamicScene', 'Room', 'Source', 'MicrophoneArray', 'Scene', 'StaticScene', 'RIRResult', 'load', 'save']

DynamicScene dataclass

Container for dynamic scene simulation inputs.

Examples:

scene = DynamicScene(room=room, sources=sources, mics=mics, src_traj=src_traj, mic_traj=mic_traj)
Source code in src/torchrir/models/scene.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@dataclass(frozen=True)
class DynamicScene:
    """Inputs for simulating a scene whose sources and/or mics move over time.

    Both trajectories are converted with ``as_tensor`` on construction and the
    scene is validated immediately, so construction fails fast on bad input.

    Examples:
        ```python
        scene = DynamicScene(room=room, sources=sources, mics=mics, src_traj=src_traj, mic_traj=mic_traj)
        ```
    """

    room: Room
    sources: Source
    mics: MicrophoneArray
    src_traj: Tensor
    mic_traj: Tensor

    def __post_init__(self) -> None:
        # Convert both trajectories first, then store them; the dataclass is
        # frozen, so object.__setattr__ is the only way to write the fields.
        converted = {
            "src_traj": as_tensor(self.src_traj),
            "mic_traj": as_tensor(self.mic_traj),
        }
        for field_name, value in converted.items():
            object.__setattr__(self, field_name, value)
        self._validate_internal()

    def is_dynamic(self) -> bool:
        """Always ``True``: this scene type carries trajectories."""
        return True

    def validate(self) -> None:
        """Re-run the consistency checks performed at construction."""
        self._validate_internal()

    def _validate_internal(self) -> None:
        _validate_scene_entities(self.room, self.sources, self.mics)
        dim = int(self.room.size.numel())
        src_steps = _validate_traj(
            self.src_traj, int(self.sources.positions.shape[0]), dim, "src_traj"
        )
        mic_steps = _validate_traj(
            self.mic_traj, int(self.mics.positions.shape[0]), dim, "mic_traj"
        )
        # _validate_traj returns a value compared as the time-step count; both
        # trajectories must agree on it.
        if src_steps != mic_steps:
            raise ValueError("src_traj and mic_traj must have matching time steps")

MicrophoneArray dataclass

Microphone array container.

Examples:

mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
Source code in src/torchrir/models/room.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
@dataclass(frozen=True)
class MicrophoneArray:
    """Positions (and optional orientations) of the microphones in a scene.

    Fields are normalized in ``__post_init__``; use :meth:`from_positions` to
    build an instance from plain Python sequences.

    Examples:
        ```python
        mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
        ```
    """

    positions: Tensor
    orientation: Optional[Tensor] = None

    def __post_init__(self) -> None:
        positions = _normalize_entity_positions(self.positions, name="mic")
        object.__setattr__(self, "positions", positions)
        n_mics, dim = positions.shape[0], positions.shape[1]
        orientation = _normalize_entity_orientation(
            self.orientation, n_entities=n_mics, dim=dim, name="mic"
        )
        # Only overwrite the field when normalization produced a value.
        if orientation is not None:
            object.__setattr__(self, "orientation", orientation)

    def replace(self, **kwargs) -> "MicrophoneArray":
        """Return a new MicrophoneArray with the given fields replaced."""
        return replace(self, **kwargs)

    @classmethod
    def from_positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "MicrophoneArray":
        """Build a MicrophoneArray from array-likes, converting them to tensors.

        Args:
            positions: Microphone positions (sequence of coordinates or tensor).
            orientation: Optional orientation data, converted the same way.
            device: Target device passed to ``as_tensor``.
            dtype: Target dtype passed to ``as_tensor``.
        """
        def to_tensor(value):
            return as_tensor(value, device=device, dtype=dtype)

        return cls(
            to_tensor(positions),
            None if orientation is None else to_tensor(orientation),
        )

from_positions classmethod

from_positions(positions, *, orientation=None, device=None, dtype=None)

Convert positions/orientation to tensors and build a MicrophoneArray.

Source code in src/torchrir/models/room.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
@classmethod
def from_positions(
    cls,
    positions: Sequence[Sequence[float]] | Tensor,
    *,
    orientation: Optional[Sequence[float] | Tensor] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> "MicrophoneArray":
    """Build a MicrophoneArray from array-likes, converting them to tensors.

    Args:
        positions: Microphone positions (sequence of coordinates or tensor).
        orientation: Optional orientation data, converted the same way.
        device: Target device passed to ``as_tensor``.
        dtype: Target dtype passed to ``as_tensor``.
    """
    def to_tensor(value):
        return as_tensor(value, device=device, dtype=dtype)

    return cls(
        to_tensor(positions),
        None if orientation is None else to_tensor(orientation),
    )

replace

replace(**kwargs)

Return a new MicrophoneArray with updated fields.

Source code in src/torchrir/models/room.py
154
155
156
def replace(self, **kwargs) -> "MicrophoneArray":
    """Return a new MicrophoneArray with updated fields.

    Thin wrapper that forwards ``kwargs`` to the module-level ``replace``
    helper (presumably ``dataclasses.replace`` — confirm against the imports),
    leaving ``self`` unchanged.
    """
    return replace(self, **kwargs)

RIRResult dataclass

Container for RIRs with metadata.

Examples:

from torchrir.sim import ISMSimulator
result = ISMSimulator(max_order=6, tmax=0.3).simulate(scene, config)
rirs = result.rirs
Source code in src/torchrir/models/results.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
@dataclass(frozen=True)
class RIRResult:
    """Container for RIRs with metadata.

    Bundles a simulated RIR tensor with the scene and configuration that
    produced it. The dataclass is frozen, so fields cannot be reassigned.

    Examples:
        ```python
        from torchrir.sim import ISMSimulator
        result = ISMSimulator(max_order=6, tmax=0.3).simulate(scene, config)
        rirs = result.rirs
        ```
    """

    # Simulated RIRs. When produced by ISMSimulator this is the output of
    # simulate_rir, allocated as (n_src, n_mic, nsample), or of
    # simulate_dynamic_rir, allocated as (T, n_src, n_mic, nsample).
    rirs: Tensor
    # Scene the RIRs were simulated for (ISMSimulator stores the normalized scene).
    scene: SceneLike
    # Simulation configuration used for the run.
    config: "SimulationConfig"
    # Optional timestamps for dynamic results; ISMSimulator.simulate leaves this None.
    timestamps: Optional[Tensor] = None
    # Seed recorded for the run; ISMSimulator.simulate copies config.seed here.
    seed: Optional[int] = None

Room dataclass

Room geometry and acoustic parameters.

Examples:

room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
Source code in src/torchrir/models/room.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@dataclass(frozen=True)
class Room:
    """Room geometry and acoustic parameters.

    Validation in ``__post_init__``:
    - ``size`` must be finite and strictly positive.
    - ``fs``, ``c`` and (if given) ``t60`` must be finite and positive.
    - ``beta`` (if given) must hold 4 (2-D room) or 6 (3-D room) finite values
      in [0, 1].
    - ``beta`` and ``t60`` are mutually exclusive.

    Examples:
        ```python
        room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
        ```
    """

    size: Tensor
    fs: float  # sampling rate (presumably Hz — confirm)
    c: float = 343.0  # speed of sound; default 343.0 suggests m/s — confirm
    beta: Optional[Tensor] = None  # reflection coefficients, presumably one per wall
    t60: Optional[float] = None

    def __post_init__(self) -> None:
        """Validate room size and reflection parameters.

        Raises:
            ValueError: if any parameter violates the constraints listed on
                the class docstring.
        """
        import math

        size = ensure_dim(self.size)
        if not torch.all(torch.isfinite(size)):
            raise ValueError("room size must contain finite values")
        if torch.any(size <= 0):
            raise ValueError("room size must be strictly positive")
        object.__setattr__(self, "size", size)
        # NaN compares False against everything, so `<= 0` alone would accept
        # NaN (and inf passes it too); check finiteness explicitly, matching
        # the checks already applied to size and beta.
        if not math.isfinite(self.fs):
            raise ValueError("fs must be finite")
        if self.fs <= 0:
            raise ValueError("fs must be positive")
        if not math.isfinite(self.c):
            raise ValueError("c must be finite")
        if self.c <= 0:
            raise ValueError("c must be positive")
        if self.beta is not None and self.t60 is not None:
            raise ValueError("beta and t60 are mutually exclusive")
        if self.t60 is not None and not math.isfinite(self.t60):
            raise ValueError("t60 must be finite")
        if self.t60 is not None and self.t60 <= 0:
            raise ValueError("t60 must be positive")
        if self.beta is not None:
            # Flatten to 1-D in the room's dtype before checking the count.
            beta = as_tensor(self.beta, dtype=size.dtype).view(-1)
            expected = 4 if size.numel() == 2 else 6
            if beta.numel() != expected:
                raise ValueError(
                    f"beta must have {expected} elements for {size.numel()}D rooms"
                )
            if not torch.all(torch.isfinite(beta)):
                raise ValueError("beta must contain finite values")
            if torch.any(beta < 0) or torch.any(beta > 1):
                raise ValueError("beta values must be in [0, 1]")
            object.__setattr__(self, "beta", beta)

    def replace(self, **kwargs) -> "Room":
        """Return a new Room with updated fields."""
        return replace(self, **kwargs)

    @staticmethod
    def shoebox(
        size: Sequence[float] | Tensor,
        *,
        fs: float,
        c: float = 343.0,
        beta: Optional[Sequence[float] | Tensor] = None,
        t60: Optional[float] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Room":
        """Create a rectangular (shoebox) room.

        ``size`` and ``beta`` are converted to tensors on ``device``/``dtype``
        before construction; all validation happens in ``Room.__post_init__``.

        Examples:
            ```python
            room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
            ```
        """
        size_t = as_tensor(size, device=device, dtype=dtype)
        size_t = ensure_dim(size_t)
        beta_t = None
        if beta is not None:
            beta_t = as_tensor(beta, device=device, dtype=dtype)
        return Room(size=size_t, fs=fs, c=c, beta=beta_t, t60=t60)

__post_init__

__post_init__()

Validate room size and reflection parameters.

Source code in src/torchrir/models/room.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def __post_init__(self) -> None:
    """Validate room size and reflection parameters.

    Raises:
        ValueError: if ``size`` is non-finite or non-positive; if ``fs``,
            ``c`` or ``t60`` is non-finite or non-positive; if ``beta`` and
            ``t60`` are both given; or if ``beta`` has the wrong length
            (4 for 2-D, 6 for 3-D), non-finite entries, or values outside
            [0, 1].
    """
    import math

    size = ensure_dim(self.size)
    if not torch.all(torch.isfinite(size)):
        raise ValueError("room size must contain finite values")
    if torch.any(size <= 0):
        raise ValueError("room size must be strictly positive")
    object.__setattr__(self, "size", size)
    # NaN compares False against everything, so `<= 0` alone would accept
    # NaN (and inf passes it too); check finiteness explicitly, matching the
    # checks already applied to size and beta.
    if not math.isfinite(self.fs):
        raise ValueError("fs must be finite")
    if self.fs <= 0:
        raise ValueError("fs must be positive")
    if not math.isfinite(self.c):
        raise ValueError("c must be finite")
    if self.c <= 0:
        raise ValueError("c must be positive")
    if self.beta is not None and self.t60 is not None:
        raise ValueError("beta and t60 are mutually exclusive")
    if self.t60 is not None and not math.isfinite(self.t60):
        raise ValueError("t60 must be finite")
    if self.t60 is not None and self.t60 <= 0:
        raise ValueError("t60 must be positive")
    if self.beta is not None:
        # Flatten to 1-D in the room's dtype before checking the count.
        beta = as_tensor(self.beta, dtype=size.dtype).view(-1)
        expected = 4 if size.numel() == 2 else 6
        if beta.numel() != expected:
            raise ValueError(
                f"beta must have {expected} elements for {size.numel()}D rooms"
            )
        if not torch.all(torch.isfinite(beta)):
            raise ValueError("beta must contain finite values")
        if torch.any(beta < 0) or torch.any(beta > 1):
            raise ValueError("beta values must be in [0, 1]")
        object.__setattr__(self, "beta", beta)

replace

replace(**kwargs)

Return a new Room with updated fields.

Source code in src/torchrir/models/room.py
59
60
61
def replace(self, **kwargs) -> "Room":
    """Return a new Room with updated fields.

    Forwards ``kwargs`` to the module-level ``replace`` helper (presumably
    ``dataclasses.replace`` — confirm against the imports); ``self`` is not
    modified.
    """
    return replace(self, **kwargs)

shoebox staticmethod

shoebox(size, *, fs, c=343.0, beta=None, t60=None, device=None, dtype=None)

Create a rectangular (shoebox) room.

Examples:

room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
Source code in src/torchrir/models/room.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@staticmethod
def shoebox(
    size: Sequence[float] | Tensor,
    *,
    fs: float,
    c: float = 343.0,
    beta: Optional[Sequence[float] | Tensor] = None,
    t60: Optional[float] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> "Room":
    """Build a rectangular (shoebox) Room from array-like inputs.

    ``size`` (and ``beta`` when given) are converted to tensors on the
    requested ``device``/``dtype``; validation is left to ``Room`` itself.

    Examples:
        ```python
        room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
        ```
    """
    placement = {"device": device, "dtype": dtype}
    size_tensor = ensure_dim(as_tensor(size, **placement))
    beta_tensor = None if beta is None else as_tensor(beta, **placement)
    return Room(size=size_tensor, fs=fs, c=c, beta=beta_tensor, t60=t60)

Scene dataclass

Deprecated scene wrapper.

Scene is kept for backward compatibility. Prefer StaticScene and DynamicScene to avoid ambiguous states.

Source code in src/torchrir/models/scene.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
@dataclass(frozen=True)
class Scene:
    """Deprecated scene wrapper.

    `Scene` is kept for backward compatibility. Prefer `StaticScene` and
    `DynamicScene` to avoid ambiguous states.

    Whether a scene is static or dynamic depends on the trajectories:
    - Neither trajectory given: the scene is static.
    - Both given: the scene is dynamic.
    - Only one given: construction raises ``ValueError``.

    Constructing a `Scene` emits a ``DeprecationWarning``.
    """

    room: Room
    sources: Source
    mics: MicrophoneArray
    src_traj: Optional[Tensor] = None
    mic_traj: Optional[Tensor] = None

    def __post_init__(self) -> None:
        # stacklevel=3: frame 1 is this method, frame 2 the dataclass-generated
        # __init__, frame 3 the code that constructed the Scene. Attributing the
        # warning to the caller is what lets the default filters (which show
        # DeprecationWarning for code in __main__) actually surface it.
        warnings.warn(
            "Scene is deprecated and will be removed in a future release. "
            "Use StaticScene or DynamicScene.",
            DeprecationWarning,
            stacklevel=3,
        )
        self._validate_internal()

    def _validate_internal(self) -> None:
        _validate_scene_entities(self.room, self.sources, self.mics)
        has_src = self.src_traj is not None
        has_mic = self.mic_traj is not None
        if has_src != has_mic:
            raise ValueError(
                "Scene requires both src_traj and mic_traj for dynamic scenes. "
                "Use StaticScene for static inputs."
            )
        if has_src and has_mic:
            # Type-narrowing only; the guard above already guarantees these.
            assert self.src_traj is not None
            assert self.mic_traj is not None
            dim = int(self.room.size.numel())
            n_src = int(self.sources.positions.shape[0])
            n_mic = int(self.mics.positions.shape[0])
            t_src = _validate_traj(self.src_traj, n_src, dim, "src_traj")
            t_mic = _validate_traj(self.mic_traj, n_mic, dim, "mic_traj")
            if t_src != t_mic:
                raise ValueError("src_traj and mic_traj must have matching time steps")

    def is_dynamic(self) -> bool:
        """Return True when both trajectories are present."""
        return self.src_traj is not None and self.mic_traj is not None

    def validate(self) -> None:
        """Re-run the consistency checks performed at construction."""
        self._validate_internal()

    def to_static_scene(self) -> StaticScene:
        """Convert to a StaticScene; raises ValueError for a dynamic Scene."""
        if self.is_dynamic():
            raise ValueError("dynamic Scene cannot be converted to StaticScene")
        return StaticScene(room=self.room, sources=self.sources, mics=self.mics)

    def to_dynamic_scene(self) -> DynamicScene:
        """Convert to a DynamicScene; raises ValueError for a static Scene."""
        if not self.is_dynamic() or self.src_traj is None or self.mic_traj is None:
            raise ValueError("static Scene cannot be converted to DynamicScene")
        return DynamicScene(
            room=self.room,
            sources=self.sources,
            mics=self.mics,
            src_traj=self.src_traj,
            mic_traj=self.mic_traj,
        )

Source dataclass

Source container with positions and optional orientation.

Examples:

sources = Source.from_positions([[1.0, 2.0, 1.5]])
Source code in src/torchrir/models/room.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@dataclass(frozen=True)
class Source:
    """Positions (and optional orientations) of the sound sources in a scene.

    Fields are normalized in ``__post_init__``; use :meth:`from_positions` to
    build an instance from plain Python sequences.

    Examples:
        ```python
        sources = Source.from_positions([[1.0, 2.0, 1.5]])
        ```
    """

    positions: Tensor
    orientation: Optional[Tensor] = None

    def __post_init__(self) -> None:
        positions = _normalize_entity_positions(self.positions, name="source")
        object.__setattr__(self, "positions", positions)
        n_sources, dim = positions.shape[0], positions.shape[1]
        orientation = _normalize_entity_orientation(
            self.orientation, n_entities=n_sources, dim=dim, name="source"
        )
        # Only overwrite the field when normalization produced a value.
        if orientation is not None:
            object.__setattr__(self, "orientation", orientation)

    def replace(self, **kwargs) -> "Source":
        """Return a new Source with the given fields replaced."""
        return replace(self, **kwargs)

    @classmethod
    def from_positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Source":
        """Build a Source from array-likes, converting them to tensors.

        Args:
            positions: Source positions (sequence of coordinates or tensor).
            orientation: Optional orientation data, converted the same way.
            device: Target device passed to ``as_tensor``.
            dtype: Target dtype passed to ``as_tensor``.
        """
        def to_tensor(value):
            return as_tensor(value, device=device, dtype=dtype)

        return cls(
            to_tensor(positions),
            None if orientation is None else to_tensor(orientation),
        )

from_positions classmethod

from_positions(positions, *, orientation=None, device=None, dtype=None)

Convert positions/orientation to tensors and build a Source.

Source code in src/torchrir/models/room.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@classmethod
def from_positions(
    cls,
    positions: Sequence[Sequence[float]] | Tensor,
    *,
    orientation: Optional[Sequence[float] | Tensor] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> "Source":
    """Build a Source from array-likes, converting them to tensors.

    Args:
        positions: Source positions (sequence of coordinates or tensor).
        orientation: Optional orientation data, converted the same way.
        device: Target device passed to ``as_tensor``.
        dtype: Target dtype passed to ``as_tensor``.
    """
    def to_tensor(value):
        return as_tensor(value, device=device, dtype=dtype)

    return cls(
        to_tensor(positions),
        None if orientation is None else to_tensor(orientation),
    )

replace

replace(**kwargs)

Return a new Source with updated fields.

Source code in src/torchrir/models/room.py
111
112
113
def replace(self, **kwargs) -> "Source":
    """Return a new Source with updated fields.

    Forwards ``kwargs`` to the module-level ``replace`` helper (presumably
    ``dataclasses.replace`` — confirm against the imports); ``self`` is not
    modified.
    """
    return replace(self, **kwargs)

StaticScene dataclass

Container for static scene simulation inputs.

Examples:

scene = StaticScene(room=room, sources=sources, mics=mics)
Source code in src/torchrir/models/scene.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@dataclass(frozen=True)
class StaticScene:
    """Inputs for simulating a scene with fixed source and mic positions.

    The shared entity validation (``_validate_scene_entities``) runs on
    construction and again whenever :meth:`validate` is called.

    Examples:
        ```python
        scene = StaticScene(room=room, sources=sources, mics=mics)
        ```
    """

    room: Room
    sources: Source
    mics: MicrophoneArray

    def __post_init__(self) -> None:
        self.validate()

    def is_dynamic(self) -> bool:
        """Always ``False``: a static scene carries no trajectories."""
        return False

    def validate(self) -> None:
        """Run ``_validate_scene_entities`` on the room, sources and mics."""
        _validate_scene_entities(self.room, self.sources, self.mics)

load

load(path, *, backend=None, format=None)

Deprecated top-level loader. Use torchrir.io.load_wav/torchrir.io.load_audio.

Source code in src/torchrir/__init__.py
21
22
23
24
25
26
27
28
29
def load(path: Path, *, backend: str | None = None, format: str | None = None) -> Tuple[Tensor, int]:
    """Load audio through the deprecated top-level alias.

    Deprecated: call `torchrir.io.load_wav` or `torchrir.io.load_audio`.
    Emits a ``DeprecationWarning`` attributed to the caller, then forwards all
    arguments unchanged to ``io.load_wav`` and returns its result.
    """
    message = (
        "torchrir.load is deprecated. "
        "Use torchrir.io.load_wav or torchrir.io.load_audio."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return io.load_wav(path, backend=backend, format=format)

save

save(path, audio, sample_rate, *, backend=None, format=None, normalize=True, peak=1.0, subtype=None)

Deprecated top-level saver. Use torchrir.io.save_wav/torchrir.io.save_audio.

Source code in src/torchrir/__init__.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def save(
    path: Path,
    audio: Tensor,
    sample_rate: int,
    *,
    backend: str | None = None,
    format: str | None = None,
    normalize: bool = True,
    peak: float = 1.0,
    subtype: str | None = None,
) -> None:
    """Save audio through the deprecated top-level alias.

    Deprecated: call `torchrir.io.save_wav` or `torchrir.io.save_audio`.
    Emits a ``DeprecationWarning`` attributed to the caller, then forwards all
    arguments unchanged to ``io.save_wav``.
    """
    message = (
        "torchrir.save is deprecated. "
        "Use torchrir.io.save_wav or torchrir.io.save_audio."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    options = dict(
        backend=backend,
        format=format,
        normalize=normalize,
        peak=peak,
        subtype=subtype,
    )
    io.save_wav(path, audio, sample_rate, **options)

torchrir.sim

torchrir.sim

Simulation engines and configuration for RIR generation.

Includes the ISM implementation (in torchrir.sim.ism), directivity helpers, and simulator interfaces for ISM plus placeholder ray-tracing/FDTD backends.

__all__ module-attribute

__all__ = ['ISMSimulator', 'RIRSimulator', 'directivity_gain', 'simulate_dynamic_rir', 'simulate_rir', 'split_directivity']

ISMSimulator dataclass

ISM-based simulator using the current core implementation.

Examples:

result = ISMSimulator(max_order=6, tmax=0.3).simulate(scene, config)
Source code in src/torchrir/sim/simulators.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
@dataclass(frozen=True)
class ISMSimulator:
    """ISM-based simulator using the current core implementation.

    Static scenes are dispatched to ``simulate_rir`` and ``DynamicScene``
    inputs to ``simulate_dynamic_rir``. ``max_order`` and ``tmax`` must not
    conflict with the values in the supplied config (``_ensure_no_conflict``).

    Examples:
        ```python
        result = ISMSimulator(max_order=6, tmax=0.3).simulate(scene, config)
        ```
    """

    max_order: int
    tmax: float | None = None  # at least one of tmax / nsample is required
    nsample: int | None = None
    directivity: str | tuple[str, str] | None = "omni"
    nb_img: torch.Tensor | tuple[int, ...] | None = None
    device: torch.device | str | None = None
    dtype: torch.dtype | None = None

    def __post_init__(self) -> None:
        if self.max_order < 0:
            raise ValueError("max_order must be non-negative")
        if self.tmax is None and self.nsample is None:
            raise ValueError("tmax or nsample must be provided")
        if self.tmax is not None and self.tmax <= 0:
            raise ValueError("tmax must be positive")
        if self.nsample is not None and self.nsample <= 0:
            raise ValueError("nsample must be positive")

    def simulate(
        self, scene: SceneLike, config: SimulationConfig | None = None
    ) -> RIRResult:
        """Simulate ``scene`` and wrap the RIRs in an :class:`RIRResult`.

        Args:
            scene: Any scene accepted by ``_normalize_scene``.
            config: Simulation config; ``default_config()`` is used when None.

        Returns:
            RIRResult holding the RIRs, normalized scene, config and its seed.
        """
        normalized_scene = _normalize_scene(scene)
        normalized_scene.validate()
        # Explicit None test: `config or ...` would also replace a config
        # object that happens to evaluate as falsy.
        cfg = config if config is not None else default_config()
        _ensure_no_conflict(
            field="max_order",
            simulator_value=self.max_order,
            config_value=cfg.max_order,
        )
        _ensure_no_conflict(
            field="tmax",
            simulator_value=self.tmax,
            config_value=cfg.tmax,
        )
        if isinstance(normalized_scene, DynamicScene):
            # simulate_dynamic_rir only receives trajectories, not the Source /
            # MicrophoneArray objects, so their orientations must be forwarded
            # explicitly or non-omni directivity has nothing to work with.
            # NOTE(review): forwarded only when both are set, since the
            # parameter is typed as a (source, mic) tensor pair — confirm that
            # _prepare_dynamic_tensors expects this order.
            src_ori = normalized_scene.sources.orientation
            mic_ori = normalized_scene.mics.orientation
            orientation = (
                (src_ori, mic_ori)
                if src_ori is not None and mic_ori is not None
                else None
            )
            rirs = simulate_dynamic_rir(
                room=normalized_scene.room,
                src_traj=normalized_scene.src_traj,
                mic_traj=normalized_scene.mic_traj,
                max_order=self.max_order,
                nb_img=self.nb_img,
                nsample=self.nsample,
                tmax=self.tmax,
                directivity=self.directivity,
                orientation=orientation,
                config=cfg,
                device=self.device,
                dtype=self.dtype,
            )
        else:
            rirs = simulate_rir(
                room=normalized_scene.room,
                sources=normalized_scene.sources,
                mics=normalized_scene.mics,
                max_order=self.max_order,
                nb_img=self.nb_img,
                nsample=self.nsample,
                tmax=self.tmax,
                directivity=self.directivity,
                config=cfg,
                device=self.device,
                dtype=self.dtype,
            )
        return RIRResult(rirs=rirs, scene=normalized_scene, config=cfg, seed=cfg.seed)

RIRSimulator

Bases: Protocol

Strategy interface for RIR simulation backends.

Source code in src/torchrir/sim/simulators.py
16
17
18
19
20
21
22
class RIRSimulator(Protocol):
    """Structural interface that RIR simulation backends implement.

    Being a ``Protocol``, any object exposing a compatible ``simulate`` method
    conforms; subclassing is not required.
    """

    def simulate(
        self, scene: SceneLike, config: SimulationConfig | None = None
    ) -> RIRResult:
        """Simulate ``scene`` (optionally under ``config``) and return an RIRResult."""
        ...

simulate

simulate(scene, config=None)

Run a simulation and return the result.

Source code in src/torchrir/sim/simulators.py
19
20
21
22
def simulate(
    self, scene: SceneLike, config: SimulationConfig | None = None
) -> RIRResult:
    """Simulate ``scene`` (optionally under ``config``) and return an RIRResult."""
    ...

directivity_gain

directivity_gain(pattern, cos_theta)

Compute directivity gain for a pattern given cos(theta).

Source code in src/torchrir/sim/directivity.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def directivity_gain(pattern: str, cos_theta: Tensor) -> Tensor:
    """Evaluate a directivity pattern's gain at the given cos(theta) values.

    Pattern names are case-insensitive. Supported patterns:
    - omni: constant 1.
    - half-omni: 1 where cos(theta) > 0, else 0.
    - first-order patterns ``a + b*cos(theta)``: subcardioid, cardioid,
      hypercardioid.
    - bidirectional (figure-8): returns ``cos_theta`` itself.

    Raises:
        ValueError: for an unknown pattern name.
    """
    name = pattern.lower()
    # (constant, cosine) coefficients for the first-order family.
    first_order = {
        "subcardioid": (0.75, 0.25),
        "subcard": (0.75, 0.25),
        "cardioid": (0.5, 0.5),
        "card": (0.5, 0.5),
        "hypercardioid": (0.25, 0.75),
        "hypcard": (0.25, 0.75),
    }
    if name in ("omni", "omnidirectional"):
        return torch.ones_like(cos_theta)
    if name in ("homni", "halfomni", "half-omni"):
        return (cos_theta > 0).to(cos_theta.dtype)
    if name in first_order:
        const, slope = first_order[name]
        return const + slope * cos_theta
    if name in ("bidir", "bidirectional", "figure8", "figure-8"):
        return cos_theta
    raise ValueError(f"unsupported directivity pattern: {name}")

simulate_dynamic_rir

simulate_dynamic_rir(*, room, src_traj, mic_traj, max_order, nb_img=None, nsample=None, tmax=None, directivity='omni', orientation=None, config=None, device=None, dtype=None)

Simulate time-varying RIRs for source/mic trajectories.

Source code in src/torchrir/sim/ism/api.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
def simulate_dynamic_rir(
    *,
    room: Room,
    src_traj: Tensor,
    mic_traj: Tensor,
    max_order: int | None,
    nb_img: Optional[Tensor | Tuple[int, ...]] = None,
    nsample: Optional[int] = None,
    tmax: Optional[float] = None,
    directivity: str | tuple[str, str] | None = "omni",
    orientation: Optional[Tensor | tuple[Tensor, Tensor]] = None,
    config: Optional[SimulationConfig] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Simulate time-varying RIRs for source/mic trajectories.

    Image-source contributions are computed for every trajectory time step,
    in chunks of ``cfg.image_chunk_size`` image sources (all at once when that
    is <= 0).

    Args:
        room: Room geometry and reflection parameters.
        src_traj: Source trajectory; dim 0 is time and dim 1 indexes sources
            (presumably ``(T, n_src, dim)`` — confirm with
            ``_validate_traj_shapes``).
        mic_traj: Microphone trajectory; dim 1 indexes microphones.
        max_order: Maximum reflection order; may be None, in which case
            ``_resolve_config`` presumably takes it from ``config``.
        nb_img: Optional explicit image counts for ``_image_source_indices``.
        nsample: RIR length in samples; ``_validate_dynamic_args`` returns
            the effective value (possibly derived from ``tmax``).
        tmax: RIR duration alternative to ``nsample``.
        directivity: One pattern for both ends or a ``(source, mic)`` pair.
        orientation: Orientation data required for non-omni patterns.
        config: Simulation configuration (resolved by ``_resolve_config``).
        device: Target device.
        dtype: Target dtype.

    Returns:
        RIRs allocated as ``(T, n_src, n_mic, nsample)`` and passed through
        ``apply_rir_hpf``.

    Raises:
        ValueError: if a non-omni pattern lacks the needed orientation, or the
            source orientation count does not match the number of sources.
    """
    cfg, device, max_order, tmax, directivity = _resolve_config(
        config=config,
        device=device,
        max_order=max_order,
        tmax=tmax,
        directivity=directivity,
    )
    nsample = _validate_dynamic_args(
        room=room, nsample=nsample, tmax=tmax, max_order=max_order
    )
    (
        src_traj,
        mic_traj,
        src_ori,
        mic_ori,
        room_size,
        dim,
        device,
        dtype,
    ) = _prepare_dynamic_tensors(
        room=room,
        src_traj=src_traj,
        mic_traj=mic_traj,
        orientation=orientation,
        device=device,
        dtype=dtype,
    )
    _validate_traj_shapes(src_traj, mic_traj, dim)

    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
    beta = _validate_beta(beta, dim)
    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=nb_img)
    refl = _reflection_coefficients(n_vec, beta)

    src_pattern, mic_pattern = split_directivity(directivity)
    mic_dir = None
    if mic_pattern != "omni":
        if mic_ori is None:
            raise ValueError("mic orientation required for non-omni directivity")
        mic_dir = orientation_to_unit(mic_ori, dim)

    t_steps = src_traj.shape[0]
    n_src = src_traj.shape[1]
    n_mic = mic_traj.shape[1]
    rirs = torch.zeros((t_steps, n_src, n_mic, nsample), device=device, dtype=dtype)
    # Flattened view sharing storage with `rirs`: created once (it does not
    # change between chunks) so every chunk accumulates directly into `rirs`.
    rir_flat = rirs.view(t_steps * n_src, n_mic, nsample)
    fdl = cfg.frac_delay_length
    fdl2 = (fdl - 1) // 2  # half-length of the fractional-delay filter
    img_chunk = cfg.image_chunk_size
    if img_chunk <= 0:
        img_chunk = n_vec.shape[0]

    src_dirs = None
    if src_pattern != "omni":
        if src_ori is None:
            raise ValueError("source orientation required for non-omni directivity")
        src_dirs = orientation_to_unit(src_ori, dim)
        # A single orientation vector is shared by every source.
        if src_dirs.ndim == 1:
            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
            raise ValueError("source orientation must match number of sources")

    for start in range(0, n_vec.shape[0], img_chunk):
        end = min(start + img_chunk, n_vec.shape[0])
        n_vec_chunk = n_vec[start:end]
        refl_chunk = refl[start:end]
        sample_chunk, attenuation_chunk = _compute_image_contributions_time_batch(
            src_traj,
            mic_traj,
            room_size,
            n_vec_chunk,
            refl_chunk,
            room,
            fdl2,
            src_pattern=src_pattern,
            mic_pattern=mic_pattern,
            src_dirs=src_dirs,
            mic_dir=mic_dir,
        )
        # Fold time and source axes together to reuse the static accumulator.
        sample_flat = sample_chunk.reshape(t_steps * n_src, n_mic, -1)
        attenuation_flat = attenuation_chunk.reshape(t_steps * n_src, n_mic, -1)
        _accumulate_rir_batch(rir_flat, sample_flat, attenuation_flat, cfg)

    rirs = apply_rir_hpf(rirs, room.fs, cfg)
    return rirs

simulate_rir

simulate_rir(*, room, sources, mics, max_order, nb_img=None, nsample=None, tmax=None, tdiff=None, directivity='omni', orientation=None, config=None, device=None, dtype=None)

Simulate a static RIR using the image source method.

Source code in src/torchrir/sim/ism/api.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def simulate_rir(
    *,
    room: Room,
    sources: Source | Tensor,
    mics: MicrophoneArray | Tensor,
    max_order: int | None,
    nb_img: Optional[Tensor | Tuple[int, ...]] = None,
    nsample: Optional[int] = None,
    tmax: Optional[float] = None,
    tdiff: Optional[float] = None,
    directivity: str | tuple[str, str] | None = "omni",
    orientation: Optional[Tensor | tuple[Tensor, Tensor]] = None,
    config: Optional[SimulationConfig] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Simulate a static RIR using the image source method.

    Args:
        room: Room geometry and reflection parameters.
        sources: Source entity or raw position tensor.
        mics: Microphone entity or raw position tensor.
        max_order: Maximum reflection order; may be None, in which case
            ``_resolve_config`` presumably takes it from ``config``.
        nb_img: Optional explicit image counts for ``_image_source_indices``.
        nsample: RIR length in samples; ``_validate_static_args`` returns the
            effective value (possibly derived from ``tmax``).
        tmax: RIR duration alternative to ``nsample``.
        tdiff: Start of the diffuse tail. Only applied when ``tmax`` is also
            known (after config resolution) and ``tdiff < tmax``.
        directivity: One pattern for both ends or a ``(source, mic)`` pair.
        orientation: Orientation data required for non-omni patterns.
        config: Simulation configuration (resolved by ``_resolve_config``).
        device: Target device.
        dtype: Target dtype.

    Returns:
        RIRs allocated as ``(n_src, n_mic, nsample)`` and passed through
        ``apply_rir_hpf``.

    Raises:
        ValueError: if a non-omni pattern lacks the needed orientation, or the
            source orientation count does not match the number of sources.
    """
    # Reconcile explicit arguments with the config (see _resolve_config).
    cfg, device, max_order, tmax, directivity = _resolve_config(
        config=config,
        device=device,
        max_order=max_order,
        tmax=tmax,
        directivity=directivity,
    )
    nsample = _validate_static_args(
        room=room, nsample=nsample, tmax=tmax, max_order=max_order
    )
    (
        src_pos,
        mic_pos,
        src_ori,
        mic_ori,
        room_size,
        dim,
        device,
        dtype,
    ) = _prepare_static_tensors(
        room=room,
        sources=sources,
        mics=mics,
        orientation=orientation,
        device=device,
        dtype=dtype,
    )
    _validate_pos_shapes(src_pos, mic_pos, dim)

    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
    beta = _validate_beta(beta, dim)

    # Enumerate image-source index vectors and their reflection weights.
    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=nb_img)
    refl = _reflection_coefficients(n_vec, beta)

    src_pattern, mic_pattern = split_directivity(directivity)
    mic_dir = None
    if mic_pattern != "omni":
        if mic_ori is None:
            raise ValueError("mic orientation required for non-omni directivity")
        mic_dir = orientation_to_unit(mic_ori, dim)

    n_src = src_pos.shape[0]
    n_mic = mic_pos.shape[0]
    rir = torch.zeros((n_src, n_mic, nsample), device=device, dtype=dtype)
    fdl = cfg.frac_delay_length
    fdl2 = (fdl - 1) // 2  # half-length of the fractional-delay filter
    img_chunk = cfg.image_chunk_size
    # Non-positive chunk size means "process all image sources at once".
    if img_chunk <= 0:
        img_chunk = n_vec.shape[0]

    src_dirs = None
    if src_pattern != "omni":
        if src_ori is None:
            raise ValueError("source orientation required for non-omni directivity")
        src_dirs = orientation_to_unit(src_ori, dim)
        # A single orientation vector is shared by every source.
        if src_dirs.ndim == 1:
            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
            raise ValueError("source orientation must match number of sources")

    # Chunk over image sources to bound peak memory of the batched helpers.
    for start in range(0, n_vec.shape[0], img_chunk):
        end = min(start + img_chunk, n_vec.shape[0])
        n_vec_chunk = n_vec[start:end]
        refl_chunk = refl[start:end]
        sample_chunk, attenuation_chunk = _compute_image_contributions_batch(
            src_pos,
            mic_pos,
            room_size,
            n_vec_chunk,
            refl_chunk,
            room,
            fdl2,
            src_pattern=src_pattern,
            mic_pattern=mic_pattern,
            src_dirs=src_dirs,
            mic_dir=mic_dir,
        )
        # Accumulates into `rir` (the return value is not used).
        _accumulate_rir_batch(rir, sample_chunk, attenuation_chunk, cfg)

    # NOTE(review): with only nsample given (tmax None after resolution),
    # tdiff is silently ignored here.
    if tdiff is not None and tmax is not None and tdiff < tmax:
        rir = _apply_diffuse_tail(rir, room, beta, tdiff, tmax, seed=cfg.seed)
    rir = apply_rir_hpf(rir, room.fs, cfg)
    return rir

split_directivity

split_directivity(directivity)

Normalize directivity specification into (source, mic).

Source code in src/torchrir/sim/directivity.py
27
28
29
30
31
32
33
def split_directivity(directivity: str | tuple[str, str]) -> tuple[str, str]:
    """Normalize directivity specification into (source, mic)."""
    if isinstance(directivity, (list, tuple)):
        if len(directivity) != 2:
            raise ValueError("directivity tuple must have length 2")
        return directivity[0], directivity[1]
    return directivity, directivity

torchrir.signal

torchrir.signal

Signal processing utilities for static and dynamic RIR convolution.

__all__ module-attribute

__all__ = ['DynamicConvolver', 'convolve_rir', 'fft_convolve']

DynamicConvolver dataclass

Convolver for time-varying RIRs.

Examples:

convolver = DynamicConvolver(mode="trajectory")
y = convolver.convolve(signal, rirs)
Source code in src/torchrir/signal/dynamic.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@dataclass(frozen=True)
class DynamicConvolver:
    """Immutable, callable front-end for convolving with time-varying RIRs.

    ``mode`` selects the backend: ``"hop"`` forwards ``hop``, while
    ``"trajectory"`` forwards ``timestamps`` and ``fs``.

    Examples:
        ```python
        convolver = DynamicConvolver(mode="trajectory")
        y = convolver.convolve(signal, rirs)
        ```
    """

    mode: str = "trajectory"  # "trajectory" or "hop"
    hop: Optional[int] = None  # mandatory when mode == "hop"
    timestamps: Optional[Tensor] = None  # forwarded in trajectory mode
    fs: Optional[float] = None  # forwarded in trajectory mode

    def __call__(self, signal: Tensor, rirs: Tensor) -> Tensor:
        """Shorthand for :meth:`convolve`."""
        return self.convolve(signal, rirs)

    def convolve(self, signal: Tensor, rirs: Tensor) -> Tensor:
        """Dispatch to the backend selected by ``mode`` and convolve.

        Raises:
            ValueError: If ``mode`` is unknown, or ``mode == "hop"`` without
                a ``hop`` value.

        Examples:
            ```python
            y = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)
            ```
        """
        selected = self.mode
        if selected == "hop":
            if self.hop is None:
                raise ValueError("hop must be provided for hop mode")
            return _convolve_dynamic_hop(signal, rirs, self.hop)
        if selected == "trajectory":
            return _convolve_dynamic_trajectory(
                signal, rirs, timestamps=self.timestamps, fs=self.fs
            )
        raise ValueError("mode must be 'trajectory' or 'hop'")

convolve

convolve(signal, rirs)

Convolve signals with time-varying RIRs.

Examples:

y = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)
Source code in src/torchrir/signal/dynamic.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def convolve(self, signal: Tensor, rirs: Tensor) -> Tensor:
    """Dispatch to the backend selected by ``mode`` and convolve.

    Raises:
        ValueError: If ``mode`` is unknown, or ``mode == "hop"`` without a
            ``hop`` value.

    Examples:
        ```python
        y = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)
        ```
    """
    selected = self.mode
    if selected == "hop":
        if self.hop is None:
            raise ValueError("hop must be provided for hop mode")
        return _convolve_dynamic_hop(signal, rirs, self.hop)
    if selected == "trajectory":
        return _convolve_dynamic_trajectory(
            signal, rirs, timestamps=self.timestamps, fs=self.fs
        )
    raise ValueError("mode must be 'trajectory' or 'hop'")

convolve_rir

convolve_rir(signal, rirs)

Convolve signals with static RIRs (supports multi-source/mic).

Parameters:

Name Type Description Default
signal Tensor

(n_src, n_samples) or (n_samples,) tensor.

required
rirs Tensor

(n_src, n_mic, rir_len) or compatible shape.

required

Returns:

Type Description
Tensor

(n_mic, n_samples + rir_len - 1) tensor or 1D for single mic.

Examples:

y = convolve_rir(signal, rirs)
Source code in src/torchrir/signal/static.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def convolve_rir(signal: Tensor, rirs: Tensor) -> Tensor:
    """Mix source signals through static RIRs into per-microphone outputs.

    Args:
        signal: (n_src, n_samples) or (n_samples,) tensor. A single-row signal
            is broadcast to every source in ``rirs``.
        rirs: (n_src, n_mic, rir_len) or compatible shape.

    Returns:
        (n_mic, n_samples + rir_len - 1) tensor, squeezed to 1D when there is
        a single mic. Each mic channel is the sum over sources of that
        source's signal convolved with the matching RIR.

    Raises:
        ValueError: If the signal's source count is neither 1 nor n_src.

    Examples:
        ```python
        y = convolve_rir(signal, rirs)
        ```
    """
    sig = _ensure_signal(signal)
    filt = _ensure_static_rirs(rirs)
    n_src, n_mic, rir_len = filt.shape

    n_rows = sig.shape[0]
    if n_rows != 1 and n_rows != n_src:
        raise ValueError("signal source count does not match rirs")
    if n_rows == 1 and n_src > 1:
        sig = sig.expand(n_src, -1)

    mix = torch.zeros(
        (n_mic, sig.shape[1] + rir_len - 1), dtype=sig.dtype, device=sig.device
    )
    for mic_idx in range(n_mic):
        # ``channel`` is a view into ``mix``; in-place adds write through.
        channel = mix[mic_idx]
        for src_idx in range(n_src):
            channel += fft_convolve(sig[src_idx], filt[src_idx, mic_idx])

    if n_mic == 1:
        return mix.squeeze(0)
    return mix

fft_convolve

fft_convolve(signal, rir)

Convolve a 1D signal with a 1D RIR using FFT.

Parameters:

Name Type Description Default
signal Tensor

1D signal tensor.

required
rir Tensor

1D impulse response.

required

Returns:

Type Description
Tensor

1D tensor of length len(signal) + len(rir) - 1.

Examples:

y = fft_convolve(signal, rir)
Source code in src/torchrir/signal/static.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def fft_convolve(signal: Tensor, rir: Tensor) -> Tensor:
    """Linearly convolve two 1D tensors via zero-padded real FFTs.

    Args:
        signal: 1D signal tensor.
        rir: 1D impulse response.

    Returns:
        1D tensor of length len(signal) + len(rir) - 1.

    Raises:
        ValueError: If either input is not 1D.

    Examples:
        ```python
        y = fft_convolve(signal, rir)
        ```
    """
    if not (signal.ndim == 1 and rir.ndim == 1):
        raise ValueError("fft_convolve expects 1D tensors")
    full_len = signal.shape[0] + rir.shape[0] - 1
    # Transform size: next power of two >= full_len. Padding to at least the
    # full output length prevents circular wrap-around, so the product of
    # spectra corresponds to a linear (not circular) convolution.
    n_fft = 2 ** (full_len - 1).bit_length()
    spectrum = torch.fft.rfft(signal, n=n_fft) * torch.fft.rfft(rir, n=n_fft)
    return torch.fft.irfft(spectrum, n=n_fft)[:full_len]

torchrir.geometry

torchrir.geometry

Geometry helpers for arrays, trajectories, and sampling.

Includes standard array layouts (linear, circular, polyhedron, binaural, Eigenmike) plus position sampling utilities.

__all__ module-attribute

__all__ = ['binaural_array', 'circular_array', 'clamp_positions', 'eigenmike_em32', 'eigenmike_em64', 'linear_array', 'linear_trajectory', 'polyhedron_array', 'sample_positions', 'sample_positions_min_distance']

binaural_array

binaural_array(center, *, offset=0.08, device=None, dtype=None)

Create a two-mic binaural layout around a center point.

Source code in src/torchrir/geometry/arrays.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def binaural_array(
    center: Sequence[float] | Tensor,
    *,
    offset: float = 0.08,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> Tensor:
    """Return two mic positions at ``center -/+ offset`` along the first axis.

    Row 0 is ``center`` shifted by ``-offset`` in the first coordinate and
    row 1 is shifted by ``+offset``; all other coordinates are unchanged.
    """
    mid = as_tensor(center, device=device, dtype=dtype)
    shift = torch.zeros((mid.numel(),), device=mid.device, dtype=mid.dtype)
    shift[0] = offset
    return torch.stack([mid - shift, mid + shift], dim=0)

circular_array

circular_array(center, *, num, radius, plane='xy', normal=None, device=None, dtype=None)

Create an equally spaced circular microphone array.

Source code in src/torchrir/geometry/arrays.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def circular_array(
    center: Sequence[float] | Tensor,
    *,
    num: int,
    radius: float,
    plane: str = "xy",
    normal: Sequence[float] | Tensor | None = None,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> Tensor:
    """Place ``num`` mics at equal angular steps on a circle of ``radius``.

    For a 2D ``center`` the circle lies in the plane itself. For a 3D
    ``center`` the circle plane comes from ``normal`` when given (via
    ``_basis_from_normal``), otherwise from ``plane`` (``'xy'``, ``'xz'``
    or ``'yz'``, case-insensitive).

    Raises:
        ValueError: For non-positive ``num``/``radius``, a center that is not
            2D/3D, or an unknown ``plane``.
    """
    if num <= 0:
        raise ValueError("num must be positive")
    if radius <= 0:
        raise ValueError("radius must be positive")
    center_t = as_tensor(center, device=device, dtype=dtype)
    dim = center_t.numel()
    factory = {"device": center_t.device, "dtype": center_t.dtype}

    # num + 1 samples over [0, 2*pi] minus the duplicated endpoint gives
    # num equally spaced angles.
    theta = torch.linspace(0.0, 2.0 * math.pi, num + 1, **factory)[:-1]
    unit_circle = torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1)

    if dim == 2:
        return center_t + radius * unit_circle
    if dim != 3:
        raise ValueError("center must be 2D or 3D")

    if normal is None:
        # Each named plane is spanned by two rows of the 3x3 identity.
        axis_pairs = {"xy": (0, 1), "xz": (0, 2), "yz": (1, 2)}
        pair = axis_pairs.get(plane.lower())
        if pair is None:
            raise ValueError("plane must be one of 'xy', 'xz', 'yz'")
        identity = torch.eye(3, **factory)
        basis_x, basis_y = identity[pair[0]], identity[pair[1]]
    else:
        basis_x, basis_y = _basis_from_normal(as_tensor(normal, **factory))

    circle = (
        unit_circle[:, :1] * basis_x[None, :] + unit_circle[:, 1:] * basis_y[None, :]
    )
    return center_t + radius * circle

clamp_positions

clamp_positions(positions, room_size, margin=0.1)

Clamp positions to remain inside the room with a margin.

Source code in src/torchrir/geometry/sampling.py
89
90
91
92
93
94
95
def clamp_positions(
    positions: Tensor, room_size: Tensor, margin: float = 0.1
) -> Tensor:
    """Clamp positions to remain inside the room with a margin.

    Every coordinate is limited per axis to ``[margin, room_size - margin]``.
    If that interval is empty (``margin`` larger than half a room extent),
    the lower bound wins.

    Args:
        positions: Positions to clamp; broadcast against ``room_size``.
        room_size: Per-axis room extent. May be an integer tensor.
        margin: Minimum distance kept from every wall.

    Returns:
        A new tensor with the clamped positions.
    """
    max_v = room_size - margin
    # Derive the lower bound from ``max_v`` (already promoted to a floating
    # dtype) rather than ``room_size``: ``full_like`` on an integer room_size
    # would truncate a fractional margin such as 0.1 down to 0.
    min_v = torch.full_like(max_v, margin)
    return torch.max(torch.min(positions, max_v), min_v)

eigenmike_em32

eigenmike_em32(center, *, radius=0.042, azimuth_offset_deg=0.0, device=None, dtype=None)

Create the mh acoustics Eigenmike em32 geometry (3D only).

Source code in src/torchrir/geometry/arrays.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def eigenmike_em32(
    center: Sequence[float] | Tensor,
    *,
    radius: float = 0.042,
    azimuth_offset_deg: float = 0.0,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> Tensor:
    """Create the mh acoustics Eigenmike em32 geometry (3D only).

    Builds 32 capsule positions from tabulated per-capsule angles (in
    degrees) on a sphere of ``radius`` around ``center``.

    Args:
        center: Array center; must have exactly 3 elements.
        radius: Sphere radius; must be positive. Units follow ``center``
            (presumably meters — confirm against callers).
        azimuth_offset_deg: Degrees added to every ``phi`` entry before the
            positions are built (presumably a rotation about the vertical
            axis — confirm in ``_spherical_array_from_angles``).
        device: Optional device forwarded to ``as_tensor``.
        dtype: Optional dtype forwarded to ``as_tensor``.

    Returns:
        Capsule positions built by ``_spherical_array_from_angles``
        (presumably shape ``(32, 3)`` — TODO confirm in that helper).

    Raises:
        ValueError: If ``radius <= 0`` or ``center`` does not have 3 elements.
    """
    if radius <= 0:
        raise ValueError("radius must be positive")
    center_t = as_tensor(center, device=device, dtype=dtype)
    if center_t.numel() != 3:
        raise ValueError("Eigenmike em32 requires a 3D center")
    # Per-capsule polar angle in degrees, capsule order 1..32.
    # NOTE(review): presumably colatitude measured from +z; the angle
    # convention is defined by _spherical_array_from_angles — confirm there.
    theta_deg = torch.tensor(
        [
            69.0,
            90.0,
            111.0,
            90.0,
            32.0,
            55.0,
            90.0,
            125.0,
            148.0,
            125.0,
            90.0,
            55.0,
            21.0,
            58.0,
            121.0,
            159.0,
            69.0,
            90.0,
            111.0,
            90.0,
            32.0,
            55.0,
            90.0,
            125.0,
            148.0,
            125.0,
            90.0,
            55.0,
            21.0,
            58.0,
            122.0,
            159.0,
        ],
        device=center_t.device,
        dtype=center_t.dtype,
    )
    # Per-capsule azimuth in degrees, same capsule order as theta_deg.
    phi_deg = torch.tensor(
        [
            0.0,
            32.0,
            0.0,
            328.0,
            0.0,
            45.0,
            69.0,
            45.0,
            0.0,
            315.0,
            291.0,
            315.0,
            91.0,
            90.0,
            90.0,
            89.0,
            180.0,
            212.0,
            180.0,
            148.0,
            180.0,
            225.0,
            249.0,
            225.0,
            180.0,
            135.0,
            111.0,
            135.0,
            269.0,
            270.0,
            270.0,
            271.0,
        ],
        device=center_t.device,
        dtype=center_t.dtype,
    )
    # The azimuth offset shifts every capsule's phi by the same amount.
    return _spherical_array_from_angles(
        center=center_t,
        radius=radius,
        theta_deg=theta_deg,
        phi_deg=phi_deg + azimuth_offset_deg,
    )

eigenmike_em64

eigenmike_em64(center, *, radius=0.042, azimuth_offset_deg=0.0, device=None, dtype=None)

Create the mh acoustics Eigenmike em64 geometry (3D only).

Source code in src/torchrir/geometry/arrays.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
def eigenmike_em64(
    center: Sequence[float] | Tensor,
    *,
    radius: float = 0.042,
    azimuth_offset_deg: float = 0.0,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> Tensor:
    """Create the mh acoustics Eigenmike em64 geometry (3D only).

    Builds 64 capsule positions from tabulated per-capsule angles (in
    degrees) on a sphere of ``radius`` around ``center``.

    Args:
        center: Array center; must have exactly 3 elements.
        radius: Sphere radius; must be positive. Units follow ``center``
            (presumably meters — confirm against callers).
        azimuth_offset_deg: Degrees added to every ``phi`` entry before the
            positions are built (presumably a rotation about the vertical
            axis — confirm in ``_spherical_array_from_angles``).
        device: Optional device forwarded to ``as_tensor``.
        dtype: Optional dtype forwarded to ``as_tensor``.

    Returns:
        Capsule positions built by ``_spherical_array_from_angles``
        (presumably shape ``(64, 3)`` — TODO confirm in that helper).

    Raises:
        ValueError: If ``radius <= 0`` or ``center`` does not have 3 elements.
    """
    if radius <= 0:
        raise ValueError("radius must be positive")
    center_t = as_tensor(center, device=device, dtype=dtype)
    if center_t.numel() != 3:
        raise ValueError("Eigenmike em64 requires a 3D center")
    # Per-capsule polar angle in degrees, capsule order 1..64.
    # NOTE(review): presumably colatitude measured from +z; the angle
    # convention is defined by _spherical_array_from_angles — confirm there.
    theta_deg = torch.tensor(
        [
            16.7656,
            21.9677,
            42.3941,
            13.2817,
            22.6728,
            52.6925,
            37.806,
            43.3944,
            43.9386,
            70.3132,
            33.2231,
            60.0257,
            56.4763,
            67.4936,
            93.2735,
            48.423,
            78.0793,
            62.0685,
            38.7171,
            63.8004,
            70.1946,
            96.246,
            81.0992,
            106.094,
            67.7533,
            91.7061,
            39.9985,
            68.7726,
            60.8869,
            82.2833,
            63.0247,
            89.794,
            137.5166,
            139.7604,
            135.2133,
            160.3628,
            162.577,
            142.0685,
            161.1987,
            162.577,
            115.536,
            86.2594,
            116.0164,
            95.3313,
            90.0637,
            111.4549,
            85.8671,
            130.8398,
            102.5775,
            142.6375,
            117.032,
            117.5631,
            115.8884,
            89.69,
            118.4478,
            93.9338,
            106.3875,
            81.0511,
            135.9764,
            142.6771,
            120.6556,
            133.8834,
            116.3591,
            107.464,
        ],
        device=center_t.device,
        dtype=center_t.dtype,
    )
    # Per-capsule azimuth in degrees, same capsule order as theta_deg.
    phi_deg = torch.tensor(
        [
            197.4561,
            115.734,
            81.911,
            313.3592,
            43.1785,
            46.7324,
            335.9958,
            14.5398,
            204.4547,
            206.542,
            247.3219,
            233.817,
            264.5437,
            99.6669,
            104.6842,
            120.9227,
            126.513,
            148.2368,
            162.6381,
            178.5498,
            21.2715,
            25.7834,
            47.8607,
            55.9075,
            71.4285,
            78.4921,
            293.221,
            290.5683,
            318.1354,
            334.0042,
            352.0227,
            0.0,
            174.0335,
            212.7205,
            251.9179,
            150.6471,
            240.8266,
            293.0625,
            331.0098,
            60.8266,
            226.9135,
            233.9255,
            193.6382,
            209.6696,
            183.169,
            163.7105,
            156.9524,
            139.4318,
            135.9729,
            102.3273,
            112.5511,
            83.1464,
            307.7078,
            309.1392,
            278.2519,
            282.9735,
            253.147,
            260.0688,
            59.7394,
            14.2241,
            32.4901,
            334.0753,
            2.0842,
            335.0677,
        ],
        device=center_t.device,
        dtype=center_t.dtype,
    )
    # The azimuth offset shifts every capsule's phi by the same amount.
    return _spherical_array_from_angles(
        center=center_t,
        radius=radius,
        theta_deg=theta_deg,
        phi_deg=phi_deg + azimuth_offset_deg,
    )

linear_array

linear_array(center, *, num, spacing, axis=0, direction=None, device=None, dtype=None)

Create an equally spaced linear microphone array.

Source code in src/torchrir/geometry/arrays.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def linear_array(
    center: Sequence[float] | Tensor,
    *,
    num: int,
    spacing: float,
    axis: int = 0,
    direction: Sequence[float] | Tensor | None = None,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> Tensor:
    """Create an equally spaced linear microphone array.

    Mics are placed symmetrically about ``center`` at ``spacing`` intervals,
    along coordinate ``axis`` or, when given, along the normalized
    ``direction`` vector.

    Args:
        center: Array center.
        num: Number of mics; must be positive.
        spacing: Distance between neighbouring mics; must be positive.
        axis: Coordinate axis used when ``direction`` is None.
        direction: Optional line direction; normalized internally, must be
            non-zero and have as many elements as ``center``.
        device: Optional device forwarded to ``as_tensor``.
        dtype: Optional dtype forwarded to ``as_tensor``.

    Returns:
        Tensor of shape ``(num, dim)`` for a 1D ``center`` of ``dim`` elements.

    Raises:
        ValueError: On non-positive ``num``/``spacing``, an out-of-range
            ``axis``, a ``direction`` of the wrong size, or a zero
            ``direction``.
    """
    if num <= 0:
        raise ValueError("num must be positive")
    if spacing <= 0:
        raise ValueError("spacing must be positive")
    center_t = as_tensor(center, device=device, dtype=dtype)
    dim = center_t.numel()
    if direction is None:
        if axis < 0 or axis >= dim:
            raise ValueError("axis out of range for center dimensionality")
        direction_vec = torch.zeros(
            (dim,), device=center_t.device, dtype=center_t.dtype
        )
        direction_vec[axis] = 1.0
    else:
        # Flatten so e.g. a (1, dim)-shaped direction still yields (num, dim).
        direction_vec = as_tensor(
            direction, device=center_t.device, dtype=center_t.dtype
        ).reshape(-1)
        if direction_vec.numel() != dim:
            raise ValueError("direction must match center dimensionality")
        norm = torch.linalg.norm(direction_vec)
        if norm == 0:
            # A zero vector has no direction; dividing by it would give NaNs.
            raise ValueError("direction must be non-zero")
        direction_vec = direction_vec / norm

    # Offsets are centered on zero so the array is symmetric about center.
    offsets = (
        torch.arange(num, device=center_t.device, dtype=center_t.dtype)
        - (num - 1) / 2.0
    ) * spacing
    return center_t + offsets[:, None] * direction_vec[None, :]

linear_trajectory

linear_trajectory(start, end, steps)

Create a linear trajectory between start and end.

Source code in src/torchrir/geometry/trajectories.py
 9
10
11
12
13
14
def linear_trajectory(start: Tensor, end: Tensor, steps: int) -> Tensor:
    """Create a linear trajectory between start and end.

    Args:
        start: Position at the first time step.
        end: Position at the last time step.
        steps: Number of time steps; must be at least 1. With ``steps == 1``
            the trajectory contains only ``start``.

    Returns:
        Positions stacked along a new leading time dimension, moving in
        equal increments from ``start`` to ``end`` (both inclusive).

    Raises:
        ValueError: If ``steps < 1``.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")
    if steps == 1:
        # The general formula would divide by steps - 1 == 0 and yield NaNs.
        return start.unsqueeze(0).clone()
    delta = end - start
    return torch.stack(
        [start + delta * t / (steps - 1) for t in range(steps)],
        dim=0,
    )

polyhedron_array

polyhedron_array(center, *, kind='tetrahedron', radius=0.1, device=None, dtype=None)

Create a regular polyhedron microphone array (3D only).

Source code in src/torchrir/geometry/arrays.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def polyhedron_array(
    center: Sequence[float] | Tensor,
    *,
    kind: str = "tetrahedron",
    radius: float = 0.1,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> Tensor:
    """Place mics on the vertices of a regular polyhedron (3D centers only).

    Vertices from ``_polyhedron_vertices(kind)`` are rescaled to unit length
    and then to ``radius``, so every mic lies exactly ``radius`` from
    ``center``.

    Raises:
        ValueError: If ``radius <= 0`` or ``center`` does not have 3 elements.
    """
    if radius <= 0:
        raise ValueError("radius must be positive")
    origin = as_tensor(center, device=device, dtype=dtype)
    if origin.numel() != 3:
        raise ValueError("polyhedron arrays require 3D centers")
    raw = _polyhedron_vertices(kind, device=origin.device, dtype=origin.dtype)
    unit = raw / torch.linalg.norm(raw, dim=-1, keepdim=True)
    return origin + radius * unit

sample_positions

sample_positions(*, num, room_size, rng, margin=0.5)

Sample random positions within a room with a safety margin.

Source code in src/torchrir/geometry/sampling.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def sample_positions(
    *,
    num: int,
    room_size: Tensor,
    rng: random.Random,
    margin: float = 0.5,
) -> Tensor:
    """Sample random positions within a room with a safety margin.

    Each coordinate is drawn independently with ``rng.uniform`` from
    ``[margin, room_size[i] - margin]``, so results are reproducible for a
    seeded ``rng``.

    Args:
        num: Number of positions to draw.
        room_size: Per-axis room extent; its element count sets ``dim``.
        rng: Random generator used for every draw.
        margin: Minimum distance kept from every wall.

    Returns:
        float32 tensor of shape ``(num, dim)`` (``(0, dim)`` when nothing is
        drawn).
    """
    dim = room_size.numel()
    bounds = [(margin, float(room_size[i].item()) - margin) for i in range(dim)]
    coords = [[rng.uniform(lo, hi) for lo, hi in bounds] for _ in range(num)]
    # reshape keeps a 2D (0, dim) result for num == 0 instead of a 1D (0,).
    return torch.tensor(coords, dtype=torch.float32).reshape(len(coords), dim)

sample_positions_min_distance

sample_positions_min_distance(*, num, room_size, rng, center, min_distance, z_range=(1.5, 1.8), margin=0.5, max_attempts=1000)

Sample random positions with a minimum distance from a center point.

Source code in src/torchrir/geometry/sampling.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def sample_positions_min_distance(
    *,
    num: int,
    room_size: Tensor,
    rng: random.Random,
    center: Tensor,
    min_distance: float,
    z_range: tuple[float, float] | None = (1.5, 1.8),
    margin: float = 0.5,
    max_attempts: int = 1000,
) -> Tensor:
    """Sample random positions with a minimum distance from a center point."""
    dim = room_size.numel()
    center = center.to(dtype=torch.float32).reshape(-1)
    if center.numel() != dim:
        raise ValueError("center dimension must match room_size.")
    low = [margin] * dim
    high = [float(room_size[i].item()) - margin for i in range(dim)]
    coords: List[List[float]] = []
    attempts = 0
    while len(coords) < num and attempts < max_attempts:
        attempts += 1
        point = [rng.uniform(low[i], high[i]) for i in range(dim)]
        if z_range is not None and dim >= 3:
            z_min, z_max = z_range
            z_low = max(margin, float(z_min))
            z_high = min(float(room_size[2].item()) - margin, float(z_max))
            if z_high > z_low:
                point[2] = rng.uniform(z_low, z_high)
        dist = torch.linalg.norm(torch.tensor(point) - center).item()
        if dist >= min_distance:
            coords.append(point)
    if len(coords) < num:
        raise RuntimeError("failed to sample positions with requested minimum distance")
    return torch.tensor(coords, dtype=torch.float32)

torchrir.viz

torchrir.viz

Visualization helpers for scenes and trajectories.

Provides static/dynamic plotting plus GIF/MP4 animation utilities.

__all__ module-attribute

__all__ = ['animate_scene_gif', 'animate_scene_mp4', 'save_scene_gifs', 'save_scene_layout_images', 'save_scene_videos', 'save_scene_plots', 'plot_scene_dynamic', 'plot_scene_static', 'render_scene_plots']

animate_scene_gif

animate_scene_gif(*, out_path, room, sources, mics, src_traj=None, mic_traj=None, step=1, fps=None, signal_len=None, fs=None, duration_s=None, plot_2d=True, plot_3d=False, annotate_sources=True, annotation_lines=None)

Render a GIF showing source/mic trajectories.

Source code in src/torchrir/viz/animation.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def animate_scene_gif(
    *,
    out_path: Path,
    room: Sequence[float] | torch.Tensor,
    sources: object | torch.Tensor | Sequence,
    mics: object | torch.Tensor | Sequence,
    src_traj: Optional[torch.Tensor | Sequence] = None,
    mic_traj: Optional[torch.Tensor | Sequence] = None,
    step: int = 1,
    fps: Optional[float] = None,
    signal_len: Optional[int] = None,
    fs: Optional[float] = None,
    duration_s: Optional[float] = None,
    plot_2d: bool = True,
    plot_3d: bool = False,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
) -> Path:
    """Render a GIF showing source/mic trajectories.

    The animation is built by ``_build_scene_animation`` and written with
    matplotlib's ``"pillow"`` writer. The parent directory of ``out_path`` is
    created if missing.

    Returns:
        ``out_path`` as a :class:`~pathlib.Path`.
    """
    import matplotlib.pyplot as plt

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fig, anim, fps_out = _build_scene_animation(
        room=room,
        sources=sources,
        mics=mics,
        src_traj=src_traj,
        mic_traj=mic_traj,
        step=step,
        fps=fps,
        signal_len=signal_len,
        fs=fs,
        duration_s=duration_s,
        plot_2d=plot_2d,
        plot_3d=plot_3d,
        annotate_sources=annotate_sources,
        annotation_lines=annotation_lines,
    )
    try:
        anim.save(out_path, writer="pillow", fps=fps_out)
    finally:
        # Close the figure even when encoding fails, so repeated calls do not
        # accumulate open matplotlib figures.
        plt.close(fig)
    return out_path

animate_scene_mp4

animate_scene_mp4(*, out_path, room, sources, mics, src_traj=None, mic_traj=None, step=1, fps=None, signal_len=None, fs=None, duration_s=None, plot_2d=True, plot_3d=False, annotate_sources=True, annotation_lines=None, mixture_path=None, mux_audio=True, audio_channels=(0, 1))

Render an MP4 showing source/mic trajectories.

When mux_audio is enabled and mixture_path is given, a stereo track is added with ffmpeg using the requested channel indices. The video canvas defaults to HD (1280x720).

Source code in src/torchrir/viz/animation.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def animate_scene_mp4(
    *,
    out_path: Path,
    room: Sequence[float] | torch.Tensor,
    sources: object | torch.Tensor | Sequence,
    mics: object | torch.Tensor | Sequence,
    src_traj: Optional[torch.Tensor | Sequence] = None,
    mic_traj: Optional[torch.Tensor | Sequence] = None,
    step: int = 1,
    fps: Optional[float] = None,
    signal_len: Optional[int] = None,
    fs: Optional[float] = None,
    duration_s: Optional[float] = None,
    plot_2d: bool = True,
    plot_3d: bool = False,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
    mixture_path: Path | None = None,
    mux_audio: bool = True,
    audio_channels: tuple[int, int] = (0, 1),
) -> Path:
    """Render an MP4 showing source/mic trajectories.

    When ``mux_audio`` is enabled and ``mixture_path`` is given, a stereo track
    is added with ffmpeg using the requested channel indices.
    The video canvas defaults to HD (1280x720).

    Returns:
        ``out_path`` as a :class:`~pathlib.Path`.
    """
    import matplotlib.pyplot as plt
    from matplotlib.animation import FFMpegWriter

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fig, anim, fps_out = _build_scene_animation(
        room=room,
        sources=sources,
        mics=mics,
        src_traj=src_traj,
        mic_traj=mic_traj,
        step=step,
        fps=fps,
        signal_len=signal_len,
        fs=fs,
        duration_s=duration_s,
        plot_2d=plot_2d,
        plot_3d=plot_3d,
        annotate_sources=annotate_sources,
        annotation_lines=annotation_lines,
        figsize=_MP4_FIGSIZE_INCHES,
    )
    try:
        writer = FFMpegWriter(fps=fps_out)
        anim.save(out_path, writer=writer, dpi=_MP4_DPI)
    finally:
        # Close the figure even when ffmpeg is missing or encoding fails, so
        # repeated calls do not accumulate open matplotlib figures.
        plt.close(fig)

    if mux_audio and mixture_path is not None:
        _add_stereo_audio_to_mp4(
            video_path=out_path,
            mixture_path=Path(mixture_path),
            audio_channels=audio_channels,
        )
    return out_path

plot_scene_dynamic

plot_scene_dynamic(*, room, src_traj, mic_traj, step=1, src_pos=None, mic_pos=None, ax=None, title=None, show=False, annotate_sources=True, annotation_lines=None)

Plot source and mic trajectories within a room.

If trajectories are static, only positions are plotted.

Examples:

ax = plot_scene_dynamic(
    room=[6.0, 4.0, 3.0],
    src_traj=src_traj,
    mic_traj=mic_traj,
)
Source code in src/torchrir/viz/scene.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def plot_scene_dynamic(
    *,
    room: Room | Sequence[float] | Tensor,
    src_traj: Tensor | Sequence,
    mic_traj: Tensor | Sequence,
    step: int = 1,
    src_pos: Optional[Tensor | Sequence] = None,
    mic_pos: Optional[Tensor | Sequence] = None,
    ax: Any | None = None,
    title: Optional[str] = None,
    show: bool = False,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
):
    """Draw the room plus source and mic trajectories on one axes.

    If trajectories are static, only positions are plotted. When
    ``src_pos``/``mic_pos`` are omitted, the first trajectory frame is used
    as the reference positions passed to the plotting helpers.

    Returns:
        The matplotlib axes that was drawn on.

    Examples:
        ```python
        ax = plot_scene_dynamic(
            room=[6.0, 4.0, 3.0],
            src_traj=src_traj,
            mic_traj=mic_traj,
        )
        ```
    """
    plt, ax = _setup_axes(ax, room)
    _draw_room(ax, _room_size(room, ax))

    src_traj = _as_trajectory(src_traj)
    mic_traj = _as_trajectory(mic_traj)
    if src_pos is None:
        src_ref = src_traj[0]
    else:
        src_ref = _extract_positions(src_pos, ax)
    if mic_pos is None:
        mic_ref = mic_traj[0]
    else:
        mic_ref = _extract_positions(mic_pos, ax)

    _plot_entity(ax, src_traj, src_ref, step=step, label="sources", marker="^")
    _plot_entity(
        ax,
        mic_traj,
        mic_ref,
        step=step,
        label="mics",
        marker="o",
        color=_MIC_COLOR,
        uniform_color=True,
    )
    if annotate_sources:
        _annotate_source_indices(ax, src_ref)
    _add_axes_annotation(ax, annotation_lines)

    if title:
        ax.set_title(title)
    ax.legend(loc="best")
    if show:
        plt.show()
    return ax

plot_scene_static

plot_scene_static(*, room, sources, mics, ax=None, title=None, show=False, annotate_sources=True, annotation_lines=None)

Plot a static room with source and mic positions.

Examples:

ax = plot_scene_static(
    room=[6.0, 4.0, 3.0],
    sources=[[1.0, 2.0, 1.5]],
    mics=[[2.0, 2.0, 1.5]],
)
Source code in src/torchrir/viz/scene.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def plot_scene_static(
    *,
    room: Room | Sequence[float] | Tensor,
    sources: Source | Tensor | Sequence,
    mics: MicrophoneArray | Tensor | Sequence,
    ax: Any | None = None,
    title: Optional[str] = None,
    show: bool = False,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
):
    """Draw a static room outline with source and microphone markers.

    Sources are scattered with the ``"^"`` marker and microphones with the
    ``"o"`` marker in the microphone color. Source indices are annotated
    unless ``annotate_sources`` is False.

    Examples:
        ```python
        ax = plot_scene_static(
            room=[6.0, 4.0, 3.0],
            sources=[[1.0, 2.0, 1.5]],
            mics=[[2.0, 2.0, 1.5]],
        )
        ```

    Returns:
        The axes object obtained from ``_setup_axes`` after drawing.
    """
    plt, axes = _setup_axes(ax, room)
    _draw_room(axes, _room_size(room, axes))

    source_xyz = _extract_positions(sources, axes)
    mic_xyz = _extract_positions(mics, axes)

    _scatter_positions(axes, source_xyz, label="sources", marker="^")
    _scatter_positions(axes, mic_xyz, label="mics", marker="o", color=_MIC_COLOR)
    if annotate_sources:
        _annotate_source_indices(axes, source_xyz)
    _add_axes_annotation(axes, annotation_lines)

    if title:
        axes.set_title(title)
    axes.legend(loc="best")
    if show:
        plt.show()
    return axes

render_scene_plots

render_scene_plots(*, out_dir, room, sources, mics, src_traj=None, mic_traj=None, prefix='scene', step=1, show=False, plot_2d=True, plot_3d=True, annotate_sources=True, annotation_lines=None)

Plot static and dynamic scenes and save images to disk.

Source code in src/torchrir/viz/io.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def render_scene_plots(
    *,
    out_dir: Path,
    room: Sequence[float] | torch.Tensor,
    sources: object | torch.Tensor | Sequence,
    mics: object | torch.Tensor | Sequence,
    src_traj: Optional[torch.Tensor | Sequence] = None,
    mic_traj: Optional[torch.Tensor | Sequence] = None,
    prefix: str = "scene",
    step: int = 1,
    show: bool = False,
    plot_2d: bool = True,
    plot_3d: bool = True,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
) -> tuple[list[Path], list[Path]]:
    """Plot static and dynamic scenes and save images to disk.

    For every enabled view (2D via ``plot_2d``, 3D via ``plot_3d``) that the
    room dimensionality supports, a static layout is saved as
    ``{prefix}_static_{N}d.png``. When ``src_traj`` or ``mic_traj`` is given,
    a trajectory plot is additionally saved as ``{prefix}_dynamic_{N}d.png``.

    Args:
        out_dir: Output directory; created (with parents) if missing.
        room: Room size; its number of entries decides which views are drawn.
        sources: Source positions, in any form accepted by ``_positions_to_cpu``.
        mics: Microphone positions, same conventions as ``sources``.
        src_traj: Optional source trajectory.
        mic_traj: Optional microphone trajectory.
        prefix: Filename prefix for every image.
        step: Forwarded to ``plot_scene_dynamic``.
        show: Forwarded to ``_save_axes``.
        plot_2d: Whether to render the 2D view.
        plot_3d: Whether to render the 3D view.
        annotate_sources: Whether to label source indices.
        annotation_lines: Extra text lines added to each plot.

    Returns:
        ``(static_paths, dynamic_paths)``, each in the order images were saved.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    room_size = _to_cpu(room)
    src_pos = _positions_to_cpu(sources)
    mic_pos = _positions_to_cpu(mics)
    dim = int(room_size.numel())

    # Trajectory conversion does not depend on the view, so do it once up
    # front instead of per view, and keep the caller's arguments untouched.
    has_traj = src_traj is not None or mic_traj is not None
    src_traj_t = None
    mic_traj_t = None
    if has_traj:
        steps = _traj_steps(src_traj, mic_traj)
        src_traj_t = _trajectory_to_cpu(src_traj, src_pos, steps)
        mic_traj_t = _trajectory_to_cpu(mic_traj, mic_pos, steps)

    static_paths: list[Path] = []
    dynamic_paths: list[Path] = []

    for view_dim, enabled in ((2, plot_2d), (3, plot_3d)):
        # Skip disabled views and views the room has too few dimensions for.
        if not enabled or dim < view_dim:
            continue
        view_room = room_size[:view_dim]
        view_src = src_pos[:, :view_dim]
        view_mic = mic_pos[:, :view_dim]

        ax = plot_scene_static(
            room=view_room,
            sources=view_src,
            mics=view_mic,
            title=f"Room scene ({view_dim}D static)",
            show=False,
            annotate_sources=annotate_sources,
            annotation_lines=annotation_lines,
        )
        static_path = out_dir / f"{prefix}_static_{view_dim}d.png"
        _save_axes(ax, static_path, show=show)
        static_paths.append(static_path)

        if has_traj:
            ax = plot_scene_dynamic(
                room=view_room,
                src_traj=src_traj_t[:, :, :view_dim],
                mic_traj=mic_traj_t[:, :, :view_dim],
                src_pos=view_src,
                mic_pos=view_mic,
                step=step,
                title=f"Room scene ({view_dim}D trajectories)",
                show=False,
                annotate_sources=annotate_sources,
                annotation_lines=annotation_lines,
            )
            dynamic_path = out_dir / f"{prefix}_dynamic_{view_dim}d.png"
            _save_axes(ax, dynamic_path, show=show)
            dynamic_paths.append(dynamic_path)

    return static_paths, dynamic_paths

save_scene_gifs

save_scene_gifs(*, out_dir, room, sources, mics, src_traj, mic_traj, prefix, signal_len, fs, gif_fps, logger, annotate_sources=True, annotation_lines=None)

Render trajectory GIFs.

Source code in src/torchrir/viz/io.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def save_scene_gifs(
    *,
    out_dir: Path,
    room: torch.Tensor | Sequence[float],
    sources: object,
    mics: object,
    src_traj: torch.Tensor,
    mic_traj: torch.Tensor,
    prefix: str,
    signal_len: int,
    fs: int,
    gif_fps: int,
    logger: logging.Logger,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
) -> None:
    """Render trajectory GIFs.

    Writes ``{prefix}.gif`` and, when the room has three size entries, an
    additional 3D-only ``{prefix}_3d.gif``. Any exception raised along the
    way is logged as a warning ("GIF skipped") instead of propagating.

    Args:
        out_dir: Output directory; accepts ``str`` or ``Path`` and is created
            (with parents) if it does not exist.
        room: Room size.
        sources: Source positions/container forwarded to ``animate_scene_gif``.
        mics: Microphone positions/container forwarded to ``animate_scene_gif``.
        src_traj: Source trajectory.
        mic_traj: Microphone trajectory.
        prefix: Base filename (without extension).
        signal_len: Forwarded to ``animate_scene_gif``.
        fs: Sample rate forwarded to ``animate_scene_gif``.
        gif_fps: Frames per second; values ``<= 0`` are passed as ``fps=None``.
        logger: Logger receiving "saved"/"skipped" messages.
        annotate_sources: Whether to label source indices.
        annotation_lines: Extra text lines added to each frame.
    """
    try:
        out_dir = Path(out_dir)
        # The sibling writers (render_scene_plots, save_scene_layout_images)
        # create the output directory first; do the same so a fresh path
        # does not make the GIF export fail.
        out_dir.mkdir(parents=True, exist_ok=True)
        fps = gif_fps if gif_fps > 0 else None
        common = dict(
            room=room,
            sources=sources,
            mics=mics,
            src_traj=src_traj,
            mic_traj=mic_traj,
            fps=fps,
            signal_len=signal_len,
            fs=fs,
            annotate_sources=annotate_sources,
            annotation_lines=annotation_lines,
        )

        gif_path = out_dir / f"{prefix}.gif"
        animate_scene_gif(out_path=gif_path, **common)
        logger.info("saved: %s", gif_path)

        if torch.as_tensor(room).numel() == 3:
            gif_path_3d = out_dir / f"{prefix}_3d.gif"
            animate_scene_gif(
                out_path=gif_path_3d, plot_2d=False, plot_3d=True, **common
            )
            logger.info("saved: %s", gif_path_3d)
    except Exception as exc:  # pragma: no cover - optional dependency
        logger.warning("GIF skipped: %s", exc)

save_scene_layout_images

save_scene_layout_images(*, out_dir, room, sources, mics, logger, src_traj=None, mic_traj=None, save_2d=True, save_3d=True, annotate_sources=True, annotation_lines=None, show=False)

Save layout images (including trajectories when src_traj or mic_traj is given) as room_layout_2d.png and room_layout_3d.png.

Source code in src/torchrir/viz/io.py
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def save_scene_layout_images(
    *,
    out_dir: Path,
    room: torch.Tensor | Sequence[float],
    sources: object,
    mics: object,
    logger: logging.Logger,
    src_traj: torch.Tensor | Sequence | None = None,
    mic_traj: torch.Tensor | Sequence | None = None,
    save_2d: bool = True,
    save_3d: bool = True,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
    show: bool = False,
) -> None:
    """Write ``room_layout_2d.png`` and/or ``room_layout_3d.png`` to ``out_dir``.

    When either trajectory is given, the images show the trajectories (via
    ``plot_scene_dynamic``); otherwise they show static positions (via
    ``plot_scene_static``). The 2D image requires a room with at least two
    size entries, the 3D image at least three. Any exception is logged as a
    warning ("Layout image skipped") instead of propagating.
    """
    try:
        room_size = _to_cpu(room)
        src_pos = _positions_to_cpu(sources)
        mic_pos = _positions_to_cpu(mics)
        dim = int(room_size.numel())

        out_dir.mkdir(parents=True, exist_ok=True)

        has_traj = src_traj is not None or mic_traj is not None
        src_traj_t = None
        mic_traj_t = None
        if has_traj:
            n_steps = _traj_steps(src_traj, mic_traj)
            src_traj_t = _trajectory_to_cpu(src_traj, src_pos, n_steps)
            mic_traj_t = _trajectory_to_cpu(mic_traj, mic_pos, n_steps)
        draw_dynamic = has_traj and src_traj_t is not None and mic_traj_t is not None

        # (view dims, enabled, dynamic title, static title, filename)
        views = (
            (
                2,
                save_2d,
                "Room layout and source trajectories (top view)",
                "Room layout (top view)",
                "room_layout_2d.png",
            ),
            (
                3,
                save_3d,
                "Room layout and source trajectories",
                "Room layout",
                "room_layout_3d.png",
            ),
        )
        for n, enabled, dynamic_title, static_title, filename in views:
            if not enabled or dim < n:
                continue
            if draw_dynamic:
                axes = plot_scene_dynamic(
                    room=room_size[:n],
                    src_traj=src_traj_t[:, :, :n],
                    mic_traj=mic_traj_t[:, :, :n],
                    src_pos=src_pos[:, :n],
                    mic_pos=mic_pos[:, :n],
                    title=dynamic_title,
                    show=False,
                    annotate_sources=annotate_sources,
                    annotation_lines=annotation_lines,
                )
            else:
                axes = plot_scene_static(
                    room=room_size[:n],
                    sources=src_pos[:, :n],
                    mics=mic_pos[:, :n],
                    title=static_title,
                    show=False,
                    annotate_sources=annotate_sources,
                    annotation_lines=annotation_lines,
                )
            image_path = out_dir / filename
            _save_axes(axes, image_path, show=show)
            logger.info("saved: %s", image_path)
    except Exception as exc:  # pragma: no cover - optional dependency
        logger.warning("Layout image skipped: %s", exc)

save_scene_plots

save_scene_plots(*, out_dir, room, sources, mics, src_traj=None, mic_traj=None, prefix, show, logger, plot_2d=True, plot_3d=True, annotate_sources=True, annotation_lines=None)

Plot and save scene images.

Source code in src/torchrir/viz/io.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def save_scene_plots(
    *,
    out_dir: Path,
    room: torch.Tensor | Sequence[float],
    sources: object,
    mics: object,
    src_traj: Optional[torch.Tensor | Sequence] = None,
    mic_traj: Optional[torch.Tensor | Sequence] = None,
    prefix: str,
    show: bool,
    logger: logging.Logger,
    plot_2d: bool = True,
    plot_3d: bool = True,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
) -> None:
    """Render scene plots with ``render_scene_plots`` and log every saved file.

    Static image paths are logged before dynamic ones. Any exception raised
    while plotting is logged as a warning ("Plot skipped") rather than
    propagated to the caller.
    """
    try:
        saved = render_scene_plots(
            out_dir=out_dir,
            room=room,
            sources=sources,
            mics=mics,
            src_traj=src_traj,
            mic_traj=mic_traj,
            prefix=prefix,
            show=show,
            plot_2d=plot_2d,
            plot_3d=plot_3d,
            annotate_sources=annotate_sources,
            annotation_lines=annotation_lines,
        )
        static_paths, dynamic_paths = saved
        for group in (static_paths, dynamic_paths):
            for image_path in group:
                logger.info("saved: %s", image_path)
    except Exception as exc:  # pragma: no cover - optional dependency
        logger.warning("Plot skipped: %s", exc)

save_scene_videos

save_scene_videos(*, out_dir, room, sources, mics, src_traj, mic_traj, signal_len, fs, logger, mp4_fps=None, save_3d=True, mixture_path=None, mux_audio=True, annotate_sources=True, annotation_lines=None)

Render trajectory MP4 videos.

Output names follow oobss-compatible conventions: room_layout_2d.mp4 is always written, and room_layout_3d.mp4 is added for 3D rooms when save_3d is enabled.

Source code in src/torchrir/viz/io.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def save_scene_videos(
    *,
    out_dir: Path,
    room: torch.Tensor | Sequence[float],
    sources: object,
    mics: object,
    src_traj: torch.Tensor,
    mic_traj: torch.Tensor,
    signal_len: int,
    fs: int,
    logger: logging.Logger,
    mp4_fps: float | None = None,
    save_3d: bool = True,
    mixture_path: Path | None = None,
    mux_audio: bool = True,
    annotate_sources: bool = True,
    annotation_lines: Optional[Sequence[str]] = None,
) -> None:
    """Render trajectory MP4 videos.

    Output names follow oobss-compatible conventions:
    - ``room_layout_2d.mp4``
    - ``room_layout_3d.mp4`` (3D rooms when ``save_3d`` is enabled)

    Any exception raised while rendering is logged as a warning
    ("MP4 skipped") instead of propagating.
    """
    try:
        # Arguments identical for the 2D and 3D renders.
        shared_kwargs = dict(
            room=room,
            sources=sources,
            mics=mics,
            src_traj=src_traj,
            mic_traj=mic_traj,
            fps=mp4_fps,
            signal_len=signal_len,
            fs=fs,
            annotate_sources=annotate_sources,
            annotation_lines=annotation_lines,
            mixture_path=mixture_path,
            mux_audio=mux_audio,
        )

        def _render(target: Path, *, flat: bool) -> None:
            # ``flat`` selects the 2D view; otherwise only the 3D view is drawn.
            animate_scene_mp4(
                out_path=target, plot_2d=flat, plot_3d=not flat, **shared_kwargs
            )
            logger.info("saved: %s", target)

        _render(out_dir / "room_layout_2d.mp4", flat=True)
        if torch.as_tensor(room).numel() == 3 and save_3d:
            _render(out_dir / "room_layout_3d.mp4", flat=False)
    except Exception as exc:  # pragma: no cover - optional dependency
        logger.warning("MP4 skipped: %s", exc)

torchrir.models

torchrir.models

Core data models for rooms, sources, microphones, scenes, and results.

Examples:

from torchrir import DynamicScene, RIRResult
scene = DynamicScene(room=room, sources=sources, mics=mics, src_traj=src_traj, mic_traj=mic_traj)
result = RIRResult(rirs=rirs, scene=scene, config=config)

SceneLike module-attribute

SceneLike = StaticScene | DynamicScene | Scene

__all__ module-attribute

__all__ = ['DynamicScene', 'MicrophoneArray', 'Room', 'RIRResult', 'Scene', 'SceneLike', 'StaticScene', 'Source']

DynamicScene dataclass

Container for dynamic scene simulation inputs.

Examples:

scene = DynamicScene(room=room, sources=sources, mics=mics, src_traj=src_traj, mic_traj=mic_traj)
Source code in src/torchrir/models/scene.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@dataclass(frozen=True)
class DynamicScene:
    """Simulation inputs for a scene with source and microphone trajectories.

    On construction both trajectories are converted with ``as_tensor`` and the
    scene is validated: entities must be consistent, each trajectory must fit
    its entity count and the room dimensionality, and both trajectories must
    have the same number of time steps.

    Examples:
        ```python
        scene = DynamicScene(room=room, sources=sources, mics=mics, src_traj=src_traj, mic_traj=mic_traj)
        ```
    """

    room: Room
    sources: Source
    mics: MicrophoneArray
    src_traj: Tensor
    mic_traj: Tensor

    def __post_init__(self) -> None:
        # Frozen dataclass: normalized values must be stored via object.__setattr__.
        for field_name in ("src_traj", "mic_traj"):
            object.__setattr__(self, field_name, as_tensor(getattr(self, field_name)))
        self._validate_internal()

    def is_dynamic(self) -> bool:
        """Return ``True``; this scene type always carries trajectories."""
        return True

    def validate(self) -> None:
        """Re-run the checks performed at construction time."""
        self._validate_internal()

    def _validate_internal(self) -> None:
        _validate_scene_entities(self.room, self.sources, self.mics)
        room_dim = int(self.room.size.numel())
        src_steps = _validate_traj(
            self.src_traj, int(self.sources.positions.shape[0]), room_dim, "src_traj"
        )
        mic_steps = _validate_traj(
            self.mic_traj, int(self.mics.positions.shape[0]), room_dim, "mic_traj"
        )
        if src_steps != mic_steps:
            raise ValueError("src_traj and mic_traj must have matching time steps")

MicrophoneArray dataclass

Microphone array container.

Examples:

mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
Source code in src/torchrir/models/room.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
@dataclass(frozen=True)
class MicrophoneArray:
    """Container for microphone positions and an optional orientation.

    Positions and orientation are normalized on construction by the shared
    entity helpers.

    Examples:
        ```python
        mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
        ```
    """

    positions: Tensor
    orientation: Optional[Tensor] = None

    def __post_init__(self) -> None:
        normalized = _normalize_entity_positions(self.positions, name="mic")
        object.__setattr__(self, "positions", normalized)
        n_mics, dim = normalized.shape[0], normalized.shape[1]
        orientation = _normalize_entity_orientation(
            self.orientation, n_entities=n_mics, dim=dim, name="mic"
        )
        # Leave the field as provided (i.e. None) when there is nothing to store.
        if orientation is None:
            return
        object.__setattr__(self, "orientation", orientation)

    def replace(self, **kwargs) -> "MicrophoneArray":
        """Return a copy of this array with the given fields overridden."""
        updated = replace(self, **kwargs)
        return updated

    @classmethod
    def from_positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "MicrophoneArray":
        """Build a MicrophoneArray from array-like positions and orientation.

        Both inputs are converted with ``as_tensor`` using ``device``/``dtype``.
        """
        pos = as_tensor(positions, device=device, dtype=dtype)
        ori = (
            as_tensor(orientation, device=device, dtype=dtype)
            if orientation is not None
            else None
        )
        return cls(pos, ori)

from_positions classmethod

from_positions(positions, *, orientation=None, device=None, dtype=None)

Convert positions/orientation to tensors and build a MicrophoneArray.

Source code in src/torchrir/models/room.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
@classmethod
def from_positions(
    cls,
    positions: Sequence[Sequence[float]] | Tensor,
    *,
    orientation: Optional[Sequence[float] | Tensor] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> "MicrophoneArray":
    """Build a MicrophoneArray from array-like positions and orientation.

    Both inputs are converted with ``as_tensor`` using ``device``/``dtype``.
    """
    pos = as_tensor(positions, device=device, dtype=dtype)
    ori = (
        as_tensor(orientation, device=device, dtype=dtype)
        if orientation is not None
        else None
    )
    return cls(pos, ori)

replace

replace(**kwargs)

Return a new MicrophoneArray with updated fields.

Source code in src/torchrir/models/room.py
154
155
156
def replace(self, **kwargs) -> "MicrophoneArray":
    """Return a copy of this array with the given fields overridden."""
    updated = replace(self, **kwargs)
    return updated

RIRResult dataclass

Container for RIRs with metadata.

Examples:

from torchrir.sim import ISMSimulator
result = ISMSimulator(max_order=6, tmax=0.3).simulate(scene, config)
rirs = result.rirs
Source code in src/torchrir/models/results.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
@dataclass(frozen=True)
class RIRResult:
    """Container for RIRs together with their scene, config and optional metadata.

    Examples:
        ```python
        from torchrir.sim import ISMSimulator
        result = ISMSimulator(max_order=6, tmax=0.3).simulate(scene, config)
        rirs = result.rirs
        ```
    """

    # Room impulse responses. The axis layout is decided by the producer
    # (e.g. ISMSimulator.simulate) — NOTE(review): document the exact shape.
    rirs: Tensor
    # Scene the RIRs belong to (StaticScene, DynamicScene or legacy Scene).
    scene: SceneLike
    # Simulation configuration stored alongside the RIRs.
    config: "SimulationConfig"
    # Optional time stamps; presumably only populated for dynamic scenes —
    # NOTE(review): confirm against the simulator.
    timestamps: Optional[Tensor] = None
    # Optional integer seed; presumably the RNG seed used — confirm.
    seed: Optional[int] = None

Room dataclass

Room geometry and acoustic parameters.

Examples:

room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
Source code in src/torchrir/models/room.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@dataclass(frozen=True)
class Room:
    """Room geometry and acoustic parameters.

    Validated on construction: ``size`` must be finite and strictly positive;
    ``fs`` and ``c`` must be positive; ``beta`` and ``t60`` are mutually
    exclusive; ``t60`` (if given) must be positive; ``beta`` (if given) must
    contain 4 values for 2D rooms (otherwise 6), all finite and in ``[0, 1]``.

    Examples:
        ```python
        room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
        ```
    """

    size: Tensor
    fs: float
    c: float = 343.0
    beta: Optional[Tensor] = None
    t60: Optional[float] = None

    def __post_init__(self) -> None:
        """Validate room size and reflection parameters.

        Raises:
            ValueError: If any parameter is non-finite, out of range, or if
                ``beta`` and ``t60`` are both given.
        """
        size = ensure_dim(self.size)
        if not torch.all(torch.isfinite(size)):
            raise ValueError("room size must contain finite values")
        if torch.any(size <= 0):
            raise ValueError("room size must be strictly positive")
        object.__setattr__(self, "size", size)
        # Use ``not (x > 0)`` instead of ``x <= 0``: NaN compares False with
        # everything, so ``NaN <= 0`` would let a NaN value pass validation.
        if not (self.fs > 0):
            raise ValueError("fs must be positive")
        if not (self.c > 0):
            raise ValueError("c must be positive")
        if self.beta is not None and self.t60 is not None:
            raise ValueError("beta and t60 are mutually exclusive")
        if self.t60 is not None and not (self.t60 > 0):
            raise ValueError("t60 must be positive")
        if self.beta is not None:
            beta = as_tensor(self.beta, dtype=size.dtype).view(-1)
            expected = 4 if size.numel() == 2 else 6
            if beta.numel() != expected:
                raise ValueError(
                    f"beta must have {expected} elements for {size.numel()}D rooms"
                )
            if not torch.all(torch.isfinite(beta)):
                raise ValueError("beta must contain finite values")
            if torch.any(beta < 0) or torch.any(beta > 1):
                raise ValueError("beta values must be in [0, 1]")
            object.__setattr__(self, "beta", beta)

    def replace(self, **kwargs) -> "Room":
        """Return a new Room with updated fields (validated again on creation)."""
        return replace(self, **kwargs)

    @staticmethod
    def shoebox(
        size: Sequence[float] | Tensor,
        *,
        fs: float,
        c: float = 343.0,
        beta: Optional[Sequence[float] | Tensor] = None,
        t60: Optional[float] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Room":
        """Create a rectangular (shoebox) room.

        ``size`` and ``beta`` are converted with ``as_tensor`` using the given
        ``device``/``dtype``; value checks happen in ``__post_init__``.

        Examples:
            ```python
            room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
            ```
        """
        size_t = as_tensor(size, device=device, dtype=dtype)
        size_t = ensure_dim(size_t)
        beta_t = None
        if beta is not None:
            beta_t = as_tensor(beta, device=device, dtype=dtype)
        return Room(size=size_t, fs=fs, c=c, beta=beta_t, t60=t60)

__post_init__

__post_init__()

Validate room size and reflection parameters.

Source code in src/torchrir/models/room.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def __post_init__(self) -> None:
    """Validate room size and reflection parameters.

    Raises:
        ValueError: If any parameter is non-finite, out of range, or if
            ``beta`` and ``t60`` are both given.
    """
    size = ensure_dim(self.size)
    if not torch.all(torch.isfinite(size)):
        raise ValueError("room size must contain finite values")
    if torch.any(size <= 0):
        raise ValueError("room size must be strictly positive")
    object.__setattr__(self, "size", size)
    # Use ``not (x > 0)`` instead of ``x <= 0``: NaN compares False with
    # everything, so ``NaN <= 0`` would let a NaN value pass validation.
    if not (self.fs > 0):
        raise ValueError("fs must be positive")
    if not (self.c > 0):
        raise ValueError("c must be positive")
    if self.beta is not None and self.t60 is not None:
        raise ValueError("beta and t60 are mutually exclusive")
    if self.t60 is not None and not (self.t60 > 0):
        raise ValueError("t60 must be positive")
    if self.beta is not None:
        beta = as_tensor(self.beta, dtype=size.dtype).view(-1)
        expected = 4 if size.numel() == 2 else 6
        if beta.numel() != expected:
            raise ValueError(
                f"beta must have {expected} elements for {size.numel()}D rooms"
            )
        if not torch.all(torch.isfinite(beta)):
            raise ValueError("beta must contain finite values")
        if torch.any(beta < 0) or torch.any(beta > 1):
            raise ValueError("beta values must be in [0, 1]")
        object.__setattr__(self, "beta", beta)

replace

replace(**kwargs)

Return a new Room with updated fields.

Source code in src/torchrir/models/room.py
59
60
61
def replace(self, **kwargs) -> "Room":
    """Return a copy of this room with the given fields overridden."""
    updated = replace(self, **kwargs)
    return updated

shoebox staticmethod

shoebox(size, *, fs, c=343.0, beta=None, t60=None, device=None, dtype=None)

Create a rectangular (shoebox) room.

Examples:

room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
Source code in src/torchrir/models/room.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@staticmethod
def shoebox(
    size: Sequence[float] | Tensor,
    *,
    fs: float,
    c: float = 343.0,
    beta: Optional[Sequence[float] | Tensor] = None,
    t60: Optional[float] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> "Room":
    """Build a rectangular (shoebox) room from array-like inputs.

    ``size`` and ``beta`` are converted with ``as_tensor`` on the requested
    ``device``/``dtype``; value checks are left to ``Room.__post_init__``.

    Examples:
        ```python
        room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
        ```
    """
    size_t = ensure_dim(as_tensor(size, device=device, dtype=dtype))
    beta_t = None if beta is None else as_tensor(beta, device=device, dtype=dtype)
    return Room(size=size_t, fs=fs, c=c, beta=beta_t, t60=t60)

Scene dataclass

Deprecated scene wrapper.

Scene is kept for backward compatibility. Prefer StaticScene and DynamicScene to avoid ambiguous states.

Source code in src/torchrir/models/scene.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
@dataclass(frozen=True)
class Scene:
    """Deprecated scene wrapper.

    `Scene` is kept for backward compatibility. Prefer `StaticScene` and
    `DynamicScene` to avoid ambiguous states. Constructing it emits a
    ``DeprecationWarning``; trajectories must be given together or not at all.
    """

    room: Room
    sources: Source
    mics: MicrophoneArray
    src_traj: Optional[Tensor] = None
    mic_traj: Optional[Tensor] = None

    def __post_init__(self) -> None:
        warnings.warn(
            "Scene is deprecated and will be removed in a future release. "
            "Use StaticScene or DynamicScene.",
            DeprecationWarning,
            stacklevel=2,
        )
        self._validate_internal()

    def _validate_internal(self) -> None:
        _validate_scene_entities(self.room, self.sources, self.mics)
        src_traj = self.src_traj
        mic_traj = self.mic_traj
        # Exactly one trajectory present is an ambiguous state.
        if (src_traj is None) != (mic_traj is None):
            raise ValueError(
                "Scene requires both src_traj and mic_traj for dynamic scenes. "
                "Use StaticScene for static inputs."
            )
        if src_traj is None or mic_traj is None:
            # Neither trajectory is set: static scene, nothing more to check.
            return
        room_dim = int(self.room.size.numel())
        src_steps = _validate_traj(
            src_traj, int(self.sources.positions.shape[0]), room_dim, "src_traj"
        )
        mic_steps = _validate_traj(
            mic_traj, int(self.mics.positions.shape[0]), room_dim, "mic_traj"
        )
        if src_steps != mic_steps:
            raise ValueError("src_traj and mic_traj must have matching time steps")

    def is_dynamic(self) -> bool:
        """Return whether both trajectories are present."""
        return self.src_traj is not None and self.mic_traj is not None

    def validate(self) -> None:
        """Re-run the checks performed at construction time."""
        self._validate_internal()

    def to_static_scene(self) -> StaticScene:
        """Convert a trajectory-free Scene into a StaticScene."""
        if self.is_dynamic():
            raise ValueError("dynamic Scene cannot be converted to StaticScene")
        return StaticScene(room=self.room, sources=self.sources, mics=self.mics)

    def to_dynamic_scene(self) -> DynamicScene:
        """Convert a Scene with both trajectories into a DynamicScene."""
        src_traj = self.src_traj
        mic_traj = self.mic_traj
        if not self.is_dynamic() or src_traj is None or mic_traj is None:
            raise ValueError("static Scene cannot be converted to DynamicScene")
        return DynamicScene(
            room=self.room,
            sources=self.sources,
            mics=self.mics,
            src_traj=src_traj,
            mic_traj=mic_traj,
        )

Source dataclass

Source container with positions and optional orientation.

Examples:

sources = Source.from_positions([[1.0, 2.0, 1.5]])
Source code in src/torchrir/models/room.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@dataclass(frozen=True)
class Source:
    """Container for source positions and an optional orientation.

    Positions and orientation are normalized on construction by the shared
    entity helpers.

    Examples:
        ```python
        sources = Source.from_positions([[1.0, 2.0, 1.5]])
        ```
    """

    positions: Tensor
    orientation: Optional[Tensor] = None

    def __post_init__(self) -> None:
        normalized = _normalize_entity_positions(self.positions, name="source")
        object.__setattr__(self, "positions", normalized)
        n_sources, dim = normalized.shape[0], normalized.shape[1]
        orientation = _normalize_entity_orientation(
            self.orientation, n_entities=n_sources, dim=dim, name="source"
        )
        # Leave the field as provided (i.e. None) when there is nothing to store.
        if orientation is None:
            return
        object.__setattr__(self, "orientation", orientation)

    def replace(self, **kwargs) -> "Source":
        """Return a copy of this source set with the given fields overridden."""
        updated = replace(self, **kwargs)
        return updated

    @classmethod
    def from_positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Source":
        """Build a Source from array-like positions and orientation.

        Both inputs are converted with ``as_tensor`` using ``device``/``dtype``.
        """
        pos = as_tensor(positions, device=device, dtype=dtype)
        ori = (
            as_tensor(orientation, device=device, dtype=dtype)
            if orientation is not None
            else None
        )
        return cls(pos, ori)

from_positions classmethod

from_positions(positions, *, orientation=None, device=None, dtype=None)

Convert positions/orientation to tensors and build a Source.

Source code in src/torchrir/models/room.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@classmethod
def from_positions(
    cls,
    positions: Sequence[Sequence[float]] | Tensor,
    *,
    orientation: Optional[Sequence[float] | Tensor] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> "Source":
    """Build a Source from array-like positions and orientation.

    Both inputs are converted with ``as_tensor`` using ``device``/``dtype``.
    """
    pos = as_tensor(positions, device=device, dtype=dtype)
    ori = (
        as_tensor(orientation, device=device, dtype=dtype)
        if orientation is not None
        else None
    )
    return cls(pos, ori)

replace

replace(**kwargs)

Return a new Source with updated fields.

Source code in src/torchrir/models/room.py
111
112
113
def replace(self, **kwargs) -> "Source":
    """Return a copy of this Source with ``kwargs`` overriding its fields.

    Delegates to the module-level ``replace`` helper (presumably
    ``dataclasses.replace``).
    """
    updated = replace(self, **kwargs)
    return updated

StaticScene dataclass

Container for static scene simulation inputs.

Examples:

scene = StaticScene(room=room, sources=sources, mics=mics)
Source code in src/torchrir/models/scene.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@dataclass(frozen=True)
class StaticScene:
    """Immutable bundle of the inputs for a static (time-invariant) simulation.

    ``room``, ``sources`` and ``mics`` are cross-checked by
    ``_validate_scene_entities`` on construction and again on every
    ``validate()`` call.

    Examples:
        ```python
        scene = StaticScene(room=room, sources=sources, mics=mics)
        ```
    """

    room: Room
    sources: Source
    mics: MicrophoneArray

    def __post_init__(self) -> None:
        # Fail fast so an inconsistent scene is never constructed.
        self._validate_internal()

    def is_dynamic(self) -> bool:
        """Static scenes carry no trajectories, so this is always False."""
        return False

    def validate(self) -> None:
        """Re-run the scene consistency checks."""
        self._validate_internal()

    def _validate_internal(self) -> None:
        # Shared by construction and validate(), mirroring DynamicScene.
        _validate_scene_entities(self.room, self.sources, self.mics)

torchrir.io

torchrir.io

I/O helpers for audio files and metadata serialization.

_AUDIO_BACKENDS module-attribute

# Registry of audio backends keyed by backend name; "soundfile" is the only built-in entry.
_AUDIO_BACKENDS = {'soundfile': AudioBackend(name='soundfile', load=_soundfile_load, save=_soundfile_save, info=info_audio)}

_DEFAULT_AUDIO_BACKEND module-attribute

# Backend used when callers pass backend=None; changed via set_audio_backend().
_DEFAULT_AUDIO_BACKEND = 'soundfile'

__all__ module-attribute

# Public names exported by torchrir.io.
__all__ = ['AudioData', 'AudioBackend', 'build_metadata', 'get_audio_backend', 'info', 'info_audio', 'info_wav', 'list_audio_backends', 'load', 'load_audio', 'load_audio_data', 'load_wav', 'save_scene_audio', 'save_scene_metadata', 'save_audio', 'save_audio_data', 'save_metadata_json', 'save', 'save_wav', 'set_audio_backend']

AudioBackend dataclass

Audio I/O backend definition.

Source code in src/torchrir/io/__init__.py
25
26
27
28
29
30
31
32
@dataclass(frozen=True)
class AudioBackend:
    """Pluggable audio I/O implementation registered under ``name``.

    Each operation is a plain callable, so a backend can be assembled from
    free functions.
    """

    # Registry key used to select this backend.
    name: str
    # (path, caller) -> (audio, sample_rate); ``caller`` is a label supplied
    # by the public entry point (e.g. "load" or "load_audio").
    load: Callable[[Path, str], Tuple[Tensor, int]]
    # (path, audio, sample_rate, normalize, peak, subtype) -> None, called
    # positionally in that order.
    save: Callable[[Path, Tensor, int, bool, float, str | None], None]
    # path -> AudioInfo describing the file header.
    info: Callable[[Path], AudioInfo]

AudioData dataclass

Audio payload with metadata needed for explicit I/O round trips.

Source code in src/torchrir/io/audio.py
13
14
15
16
17
18
19
20
@dataclass(frozen=True)
class AudioData:
    """Audio samples bundled with the metadata needed to write them back.

    ``format`` and ``subtype`` are optional; ``save_audio_data`` reuses the
    stored ``subtype`` when the caller gives no override.
    """

    # Sample tensor.
    audio: torch.Tensor
    # Sample rate in Hz.
    sample_rate: int
    # Container format name, if known.
    format: Optional[str] = None
    # Encoding subtype (e.g. "PCM_16"), if known.
    subtype: Optional[str] = None

AudioInfo dataclass

Basic audio file metadata.

Source code in src/torchrir/io/audio.py
23
24
25
26
27
28
29
30
31
32
@dataclass(frozen=True)
class AudioInfo:
    """Header-level description of an audio file (no samples are loaded)."""

    # Sample rate in Hz.
    sample_rate: int
    # Number of frames (samples per channel).
    num_frames: int
    # Number of channels.
    num_channels: int
    # Container format as reported by the backend.
    format: str
    # Encoding subtype as reported by the backend.
    subtype: str
    # Duration as reported by the backend (seconds for soundfile).
    duration: float

_info_audio_file

_info_audio_file(path)

Return metadata for an audio file (wav/flac/other supported by soundfile).

Source code in src/torchrir/io/audio.py
206
207
208
209
210
211
212
213
214
215
216
217
218
def info_audio(path: Path) -> AudioInfo:
    """Return metadata for an audio file (wav/flac/other supported by soundfile).

    Only the file header is read via ``soundfile.info``; the sample data
    is not loaded.
    """
    import soundfile as sf

    header = sf.info(str(path))
    fields = {
        "sample_rate": header.samplerate,
        "num_frames": header.frames,
        "num_channels": header.channels,
        "format": header.format,
        "subtype": header.subtype,
        "duration": float(header.duration),
    }
    return AudioInfo(**fields)

_load_audio

_load_audio(path, *, caller)

Load an audio file and return mono audio and sample rate.

Source code in src/torchrir/io/audio.py
58
59
60
61
def _load_audio(path: Path, *, caller: str) -> Tuple[torch.Tensor, int]:
    """Load ``path`` and return ``(audio, sample_rate)``.

    This is a thin tuple-returning view over ``_load_audio_data``;
    ``caller`` is forwarded unchanged.
    """
    loaded = _load_audio_data(path, caller=caller)
    samples, rate = loaded.audio, loaded.sample_rate
    return samples, rate

_normalize_format

_normalize_format(path, fmt)
Source code in src/torchrir/io/__init__.py
100
101
102
103
104
105
106
107
def _normalize_format(path: Path, fmt: str | None) -> str:
    fmt = (fmt or path.suffix.lstrip(".")).lower()
    if not fmt:
        raise ValueError(
            "Audio format could not be inferred from the path. "
            "Pass format='wav' or use torchrir.io.audio.load_audio/save_audio."
        )
    return fmt

_resolve_backend

_resolve_backend(name)
Source code in src/torchrir/io/__init__.py
91
92
93
94
95
96
97
def _resolve_backend(name: str | None) -> AudioBackend:
    """Look up a registered backend, falling back to the current default.

    Raises:
        ValueError: If the (possibly defaulted) name is not registered.
    """
    key = _DEFAULT_AUDIO_BACKEND if not name else name
    backend = _AUDIO_BACKENDS.get(key)
    if backend is None:
        raise ValueError(
            f"Unknown audio backend '{key}'. Available: {sorted(_AUDIO_BACKENDS)}"
        )
    return backend

_save_audio

_save_audio(path, audio, sample_rate, *, normalize=True, peak=1.0, subtype=None)

Save a mono or multi-channel audio file to disk.

Source code in src/torchrir/io/audio.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def _save_audio(
    path: Path,
    audio: torch.Tensor,
    sample_rate: int,
    *,
    normalize: bool = True,
    peak: float = 1.0,
    subtype: str | None = None,
) -> None:
    """Save a mono or multi-channel audio file to disk.

    Args:
        path: Destination file, passed to ``soundfile.write`` as a string.
        audio: Samples as a 1-D or 2-D tensor. A 2-D tensor with at most 8
            rows is treated as ``(channels, frames)`` and transposed to the
            ``(frames, channels)`` layout soundfile expects; other shapes
            are written as given.
        sample_rate: Sample rate in Hz.
        normalize: If True, scale so the absolute peak equals ``peak``.
            Silent or empty audio is left unscaled.
        peak: Target absolute peak when ``normalize`` is True.
        subtype: Optional soundfile subtype (e.g. ``"PCM_16"``). When None,
            a ``_torchrir_subtype`` attribute on ``audio`` is honored.

    Raises:
        ValueError: If ``normalize`` is True and ``peak`` is not positive.
    """
    import soundfile as sf

    if subtype is None:
        # Backward-compatible fallback for tensors that carry custom attrs.
        # Read it from the caller's tensor *before* converting: ``detach()``
        # returns a new tensor object without the instance attribute, so a
        # lookup after the conversion below would always yield None.
        subtype = getattr(audio, "_torchrir_subtype", None)

    audio = audio.detach().cpu().to(torch.float32)
    if normalize:
        if peak <= 0:
            raise ValueError("peak must be positive when normalize=True")
        max_val = float(audio.abs().max().item()) if audio.numel() else 0.0
        if max_val > 0:
            audio = audio / max_val * peak
    if audio.ndim == 2 and audio.shape[0] <= 8:
        # Heuristic channels-first detection (see docstring).
        audio = audio.transpose(0, 1)
    sf.write(str(path), audio.numpy(), sample_rate, subtype=subtype)

_soundfile_load

_soundfile_load(path, caller)
Source code in src/torchrir/io/__init__.py
35
36
def _soundfile_load(path: Path, caller: str) -> Tuple[Tensor, int]:
    """Backend ``load`` hook: positional ``caller`` adapted to ``_load_audio``."""
    result = _load_audio(path, caller=caller)
    return result

_soundfile_save

_soundfile_save(path, audio, sample_rate, normalize, peak, subtype)
Source code in src/torchrir/io/__init__.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def _soundfile_save(
    path: Path,
    audio: Tensor,
    sample_rate: int,
    normalize: bool,
    peak: float,
    subtype: str | None,
) -> None:
    """Backend ``save`` hook: positional options adapted to ``_save_audio``."""
    options = {"normalize": normalize, "peak": peak, "subtype": subtype}
    _save_audio(path, audio, sample_rate, **options)

build_metadata

build_metadata(*, room, sources, mics, rirs, src_traj=None, mic_traj=None, timestamps=None, signal_len=None, source_info=None, extra=None)

Build JSON-serializable metadata for a simulation output.

Examples:

metadata = build_metadata(
    room=room,
    sources=sources,
    mics=mics,
    rirs=rirs,
    src_traj=src_traj,
    mic_traj=mic_traj,
    signal_len=signal.shape[-1],
)
save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
Source code in src/torchrir/io/metadata.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def build_metadata(
    *,
    room: Room,
    sources: Source,
    mics: MicrophoneArray,
    rirs: Tensor,
    src_traj: Optional[Tensor] = None,
    mic_traj: Optional[Tensor] = None,
    timestamps: Optional[Tensor] = None,
    signal_len: Optional[int] = None,
    source_info: Optional[Any] = None,
    extra: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build JSON-serializable metadata for a simulation output.

    Args:
        room: Room description; ``size``, ``c``, ``beta``, ``t60`` and ``fs``
            are recorded, and ``fs`` drives the time axis.
        sources: Static source layout (``positions``/``orientation``).
        mics: Static microphone layout; also summarized by
            ``_array_attributes`` into the ``"array"`` section.
        rirs: RIR tensor. Only its shape is stored, and its last dimension
            is taken as the number of samples per RIR.
        src_traj: Optional source trajectory, normalized by
            ``_normalize_traj`` against ``sources.positions`` into a 3-D
            tensor whose first axis is time.
        mic_traj: Optional microphone trajectory, normalized likewise.
        timestamps: Optional explicit per-step timestamps, used as-is.
        signal_len: Signal length in samples. When ``timestamps`` is not
            given and the scene is dynamic, timestamps are spaced evenly
            over ``[0, (signal_len - 1) / fs]``.
        source_info: Optional extra source description, stored under
            ``"source_info"``.
        extra: Optional free-form dict, stored under ``"extra"`` when
            non-empty.

    Returns:
        A dict with ``room``, ``sources``, ``mics``, ``trajectories``,
        ``array``, ``time_axis``, ``doa``, ``timestamps``, ``rirs_shape`` and
        ``dynamic`` keys, plus ``source_info``/``extra`` when provided.

    Raises:
        ValueError: If the source and microphone trajectories have different
            numbers of time steps after single-step inputs are broadcast.
            Helper functions may raise their own errors too.

    Examples:
        ```python
        metadata = build_metadata(
            room=room,
            sources=sources,
            mics=mics,
            rirs=rirs,
            src_traj=src_traj,
            mic_traj=mic_traj,
            signal_len=signal.shape[-1],
        )
        save_metadata_json(Path(\"outputs/scene_metadata.json\"), metadata)
        ```
    """
    nsample = int(rirs.shape[-1])
    fs = float(room.fs)
    # Per-sample time axis (seconds) for one RIR.
    time_axis = {
        "fs": fs,
        "nsample": nsample,
        "t": _to_serializable(torch.arange(nsample, dtype=torch.float32) / fs),
    }

    src_pos = sources.positions
    mic_pos = mics.positions
    dim = int(room.size.numel())
    src_traj_n = _normalize_traj(src_traj, src_pos, dim, "src_traj")
    mic_traj_n = _normalize_traj(mic_traj, mic_pos, dim, "mic_traj")

    # Broadcast a single-step (static) trajectory to the other's length so
    # both share one time axis before computing DOA.
    t_steps = max(src_traj_n.shape[0], mic_traj_n.shape[0])
    if src_traj_n.shape[0] == 1 and t_steps > 1:
        src_traj_n = src_traj_n.expand(t_steps, -1, -1)
    if mic_traj_n.shape[0] == 1 and t_steps > 1:
        mic_traj_n = mic_traj_n.expand(t_steps, -1, -1)
    if src_traj_n.shape[0] != mic_traj_n.shape[0]:
        raise ValueError("src_traj and mic_traj must have matching time steps")

    azimuth, elevation = _compute_doa(src_traj_n, mic_traj_n)
    doa = {
        "frame": "world",
        "unit": "radians",
        "azimuth": _to_serializable(azimuth),
        "elevation": _to_serializable(elevation),
    }

    # Explicit timestamps win; otherwise derive evenly spaced ones for
    # dynamic scenes when the signal length is known.
    timestamps_out: Optional[Tensor] = None
    if timestamps is not None:
        timestamps_out = timestamps
    elif t_steps > 1 and signal_len is not None:
        duration = max(0.0, (float(signal_len) - 1.0) / fs)
        timestamps_out = torch.linspace(0.0, duration, t_steps, dtype=torch.float32)

    array_attrs = _array_attributes(mics)

    metadata: Dict[str, Any] = {
        "room": {
            "size": _to_serializable(room.size),
            "c": float(room.c),
            "beta": _to_serializable(room.beta) if room.beta is not None else None,
            "t60": float(room.t60) if room.t60 is not None else None,
            "fs": fs,
        },
        "sources": {
            "positions": _to_serializable(src_pos),
            "orientation": _to_serializable(sources.orientation),
        },
        "mics": {
            "positions": _to_serializable(mic_pos),
            "orientation": _to_serializable(mics.orientation),
        },
        # Trajectories are only stored for dynamic scenes.
        "trajectories": {
            "sources": _to_serializable(src_traj_n if t_steps > 1 else None),
            "mics": _to_serializable(mic_traj_n if t_steps > 1 else None),
        },
        "array": {
            "geometry": array_attrs.geometry_name,
            "positions": _to_serializable(array_attrs.positions),
            "orientation": _to_serializable(array_attrs.orientation),
            "center": _to_serializable(array_attrs.center),
            "normal": _to_serializable(array_attrs.normal),
            "spacing": array_attrs.spacing,
        },
        "time_axis": time_axis,
        "doa": doa,
        "timestamps": _to_serializable(timestamps_out),
        "rirs_shape": list(rirs.shape),
        "dynamic": bool(t_steps > 1),
    }

    if source_info is not None:
        metadata["source_info"] = _to_serializable(source_info)
    # Empty dicts are skipped, not just None.
    if extra:
        metadata["extra"] = _to_serializable(extra)
    return metadata

get_audio_backend

get_audio_backend()

Return the current default audio backend.

Source code in src/torchrir/io/__init__.py
74
75
76
77
def get_audio_backend() -> str:
    """Name of the backend used when callers pass ``backend=None``."""
    current = _DEFAULT_AUDIO_BACKEND
    return current

info

info(path, *, backend=None, format=None)

Deprecated wav-only metadata lookup. Use info_wav or info_audio.

Source code in src/torchrir/io/__init__.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
def info(
    path: Path,
    *,
    backend: str | None = None,
    format: str | None = None,
) -> AudioInfo:
    """Deprecated wav-only metadata lookup. Use `info_wav` or `info_audio`.

    Emits a ``DeprecationWarning`` attributed to the caller, then forwards
    to ``info_wav`` unchanged.
    """

    message = (
        "torchrir.io.info is deprecated. "
        "Use torchrir.io.info_wav or torchrir.io.info_audio."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return info_wav(path, backend=backend, format=format)

info_audio

info_audio(path, *, backend=None)

Return metadata for an audio file in any backend-supported format.

Source code in src/torchrir/io/__init__.py
208
209
210
211
212
213
214
215
216
def info_audio(
    path: Path,
    *,
    backend: str | None = None,
) -> AudioInfo:
    """Return metadata for an audio file in any backend-supported format.

    No format check is done here; the chosen backend decides what it reads.
    """

    return _resolve_backend(backend).info(path)

info_wav

info_wav(path, *, backend=None, format=None)

Return metadata for a wav file.

This entry point is wav-only. For non-wav formats, use torchrir.io.audio.info_audio.

Source code in src/torchrir/io/__init__.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def info_wav(
    path: Path,
    *,
    backend: str | None = None,
    format: str | None = None,
) -> AudioInfo:
    """Return metadata for a wav file.

    This entry point is wav-only. For non-wav formats, use
    ``torchrir.io.audio.info_audio``.

    Raises:
        ValueError: If the explicit or inferred format is not wav.
    """

    fmt = _normalize_format(path, format)
    if fmt in ("wav", "wave"):
        return _resolve_backend(backend).info(path)
    raise ValueError(
        f"info expects a wav file, got format '{fmt}'. "
        "Use torchrir.io.audio.info_audio for non-wav formats."
    )

list_audio_backends

list_audio_backends()

Return the available audio backends.

Source code in src/torchrir/io/__init__.py
68
69
70
71
def list_audio_backends() -> list[str]:
    """Return the registered backend names in sorted order."""
    names = sorted(_AUDIO_BACKENDS)
    return names

load

load(path, *, backend=None, format=None)

Deprecated wav-only loader. Use load_wav or load_audio.

Source code in src/torchrir/io/__init__.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def load(
    path: Path,
    *,
    backend: str | None = None,
    format: str | None = None,
) -> Tuple[Tensor, int]:
    """Deprecated wav-only loader. Use `load_wav` or `load_audio`.

    Emits a ``DeprecationWarning`` attributed to the caller, then forwards
    to ``load_wav`` unchanged.
    """

    message = (
        "torchrir.io.load is deprecated. "
        "Use torchrir.io.load_wav or torchrir.io.load_audio."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return load_wav(path, backend=backend, format=format)

load_audio

load_audio(path, *, backend=None)

Load an audio file in any format supported by the backend.

Source code in src/torchrir/io/__init__.py
181
182
183
184
185
186
187
188
189
def load_audio(
    path: Path,
    *,
    backend: str | None = None,
) -> Tuple[Tensor, int]:
    """Load an audio file in any format supported by the backend.

    Returns the backend's ``(audio, sample_rate)`` pair; the caller label
    passed to the backend is ``"load_audio"``.
    """

    loader = _resolve_backend(backend).load
    return loader(path, "load_audio")

load_audio_data

load_audio_data(path)

Load an audio file and return audio + metadata in a stable container.

Source code in src/torchrir/io/audio.py
64
65
66
def load_audio_data(path: Path) -> AudioData:
    """Load ``path`` into an ``AudioData`` (samples plus format metadata)."""
    data = _load_audio_data(path, caller="load_audio_data")
    return data

load_wav

load_wav(path, *, backend=None, format=None)

Load a wav file and return mono audio and sample rate.

This entry point is wav-only. For non-wav formats, use torchrir.io.audio.load_audio.

Source code in src/torchrir/io/__init__.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def load_wav(
    path: Path,
    *,
    backend: str | None = None,
    format: str | None = None,
) -> Tuple[Tensor, int]:
    """Load a wav file and return mono audio and sample rate.

    This entry point is wav-only. For non-wav formats, use
    ``torchrir.io.audio.load_audio``.

    Raises:
        ValueError: If the explicit or inferred format is not wav.
    """

    fmt = _normalize_format(path, format)
    if fmt in ("wav", "wave"):
        return _resolve_backend(backend).load(path, "load")
    raise ValueError(
        f"load expects a wav file, got format '{fmt}'. "
        "Use torchrir.io.audio.load_audio for non-wav formats."
    )

save

save(path, audio, sample_rate, *, backend=None, format=None, normalize=True, peak=1.0, subtype=None)

Deprecated wav-only saver. Use save_wav or save_audio.

Source code in src/torchrir/io/__init__.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
def save(
    path: Path,
    audio: Tensor,
    sample_rate: int,
    *,
    backend: str | None = None,
    format: str | None = None,
    normalize: bool = True,
    peak: float = 1.0,
    subtype: str | None = None,
) -> None:
    """Deprecated wav-only saver. Use `save_wav` or `save_audio`.

    Emits a ``DeprecationWarning`` attributed to the caller, then forwards
    every argument to ``save_wav`` unchanged.
    """

    message = (
        "torchrir.io.save is deprecated. "
        "Use torchrir.io.save_wav or torchrir.io.save_audio."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    forwarded = {
        "backend": backend,
        "format": format,
        "normalize": normalize,
        "peak": peak,
        "subtype": subtype,
    }
    save_wav(path, audio, sample_rate, **forwarded)

save_attribution_file

save_attribution_file(*, out_dir, dataset_attribution, modifications, attribution_name='ATTRIBUTION.txt', logger=None)

Save dataset attribution and modification notes to a text file.

Source code in src/torchrir/io/outputs.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def save_attribution_file(
    *,
    out_dir: Path,
    dataset_attribution: Mapping[str, Any] | Any,
    modifications: list[str],
    attribution_name: str = "ATTRIBUTION.txt",
    logger: Optional[logging.Logger] = None,
) -> Path:
    """Write a plain-text attribution notice into ``out_dir`` and return its path.

    The notice lists the dataset fields from ``dataset_attribution`` (after
    ``_coerce_attribution_mapping``), an optional subset, one bullet per
    entry in ``modifications``, and a redistribution reminder. ``out_dir``
    is created if missing, and the file is written as UTF-8.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    meta = _coerce_attribution_mapping(dataset_attribution)

    labelled = [
        ("Dataset", meta["dataset"]),
        ("Source", meta["source"]),
        ("License", meta["license_name"]),
        ("License URL", meta["license_url"]),
        ("Required attribution", meta["required_attribution"]),
    ]
    subset = meta.get("subset")
    if subset is not None:
        labelled.append(("Subset", subset))

    header = [
        "TorchRIR Dataset Attribution",
        "",
        "This directory contains derived audio generated with TorchRIR.",
        "",
    ]
    details = [f"{label}: {value}" for label, value in labelled]
    notes = [f"- {note}" for note in modifications]
    footer = [
        "",
        "When redistributing these derived files, keep this attribution file",
        "and include the upstream dataset license terms.",
        "",
        "See repository notice: THIRD_PARTY_DATASETS.md",
    ]
    text = "\n".join(
        [*header, *details, "", "Modifications applied in this output:", *notes, *footer]
    )

    target = out_dir / attribution_name
    target.write_text(text + "\n", encoding="utf-8")
    if logger is not None:
        logger.info("saved: %s", target)
    return target

save_audio

save_audio(path, audio, sample_rate, *, backend=None, normalize=True, peak=1.0, subtype=None)

Save an audio file in any format supported by the backend.

Source code in src/torchrir/io/__init__.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
def save_audio(
    path: Path,
    audio: Tensor,
    sample_rate: int,
    *,
    backend: str | None = None,
    normalize: bool = True,
    peak: float = 1.0,
    subtype: str | None = None,
) -> None:
    """Save an audio file in any format supported by the backend.

    No format check is done here; options are passed positionally to the
    backend's ``save`` callable.
    """

    writer = _resolve_backend(backend).save
    writer(path, audio, sample_rate, normalize, peak, subtype)

save_audio_data

save_audio_data(path, data, *, normalize=True, peak=1.0, subtype=None)

Save audio from AudioData, optionally preserving its stored subtype.

Source code in src/torchrir/io/audio.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def save_audio_data(
    path: Path,
    data: AudioData,
    *,
    normalize: bool = True,
    peak: float = 1.0,
    subtype: str | None = None,
) -> None:
    """Write an ``AudioData`` to ``path``.

    An explicit ``subtype`` overrides the one stored on ``data``; when it is
    None the stored subtype (possibly also None) is used.
    """
    effective_subtype = subtype if subtype is not None else data.subtype
    _save_audio(
        path,
        data.audio,
        data.sample_rate,
        normalize=normalize,
        peak=peak,
        subtype=effective_subtype,
    )

save_metadata_json

save_metadata_json(path, metadata)

Save metadata as JSON to the given path.

Examples:

save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
Source code in src/torchrir/io/metadata.py
138
139
140
141
142
143
144
145
146
147
148
def save_metadata_json(path: Path, metadata: Dict[str, Any]) -> None:
    """Save metadata as 2-space-indented UTF-8 JSON to the given path.

    Parent directories are created as needed. The payload is serialized
    *before* the file is opened. A value ``json`` cannot encode therefore
    raises without truncating or partially overwriting an existing file.

    Examples:
        ```python
        save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
        ```

    Raises:
        TypeError: If ``metadata`` contains values ``json`` cannot encode.
        ValueError: If ``metadata`` contains circular references.
    """
    text = json.dumps(metadata, indent=2)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")

save_scene_audio

save_scene_audio(*, out_dir, audio, fs, audio_name, logger=None)

Save scene audio to the output directory.

Source code in src/torchrir/io/outputs.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def save_scene_audio(
    *,
    out_dir: Path,
    audio: Tensor,
    fs: int,
    audio_name: str,
    logger: Optional[logging.Logger] = None,
) -> Path:
    """Write ``audio`` to ``out_dir / audio_name`` and return that path.

    ``out_dir`` is created if missing. The write is delegated to ``save``
    with its default options. The path is logged at INFO level when a
    logger is given.

    NOTE(review): if ``save`` here resolves to ``torchrir.io.save``, that
    entry point is deprecated (emits DeprecationWarning) and wav-only.
    Consider ``save_wav``/``save_audio`` once the import is confirmed.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / audio_name
    save(target, audio, fs)
    if logger is not None:
        logger.info("saved: %s", target)
    return target

save_scene_metadata

save_scene_metadata(*, out_dir, metadata_name, room, sources, mics, rirs, src_traj=None, mic_traj=None, timestamps=None, signal_len=None, source_info=None, extra=None, logger=None)

Build and save scene metadata JSON to the output directory.

Source code in src/torchrir/io/outputs.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def save_scene_metadata(
    *,
    out_dir: Path,
    metadata_name: str,
    room: "Room",
    sources: "Source",
    mics: "MicrophoneArray",
    rirs: Tensor,
    src_traj: Optional[Tensor] = None,
    mic_traj: Optional[Tensor] = None,
    timestamps: Optional[Tensor] = None,
    signal_len: Optional[int] = None,
    source_info: Optional[Any] = None,
    extra: Optional[dict[str, Any]] = None,
    logger: Optional[logging.Logger] = None,
) -> dict[str, Any]:
    """Build scene metadata, write it to ``out_dir / metadata_name`` and return it.

    All scene arguments are forwarded to ``build_metadata`` unchanged, and
    the resulting dict is written with ``save_metadata_json``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    scene_meta = build_metadata(
        room=room, sources=sources, mics=mics, rirs=rirs,
        src_traj=src_traj, mic_traj=mic_traj, timestamps=timestamps,
        signal_len=signal_len, source_info=source_info, extra=extra,
    )
    target = out_dir / metadata_name
    save_metadata_json(target, scene_meta)
    if logger is not None:
        logger.info("saved: %s", target)
    return scene_meta

save_wav

save_wav(path, audio, sample_rate, *, backend=None, format=None, normalize=True, peak=1.0, subtype=None)

Save a wav file to disk.

This entry point is wav-only. For non-wav formats, use torchrir.io.audio.save_audio.

Source code in src/torchrir/io/__init__.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def save_wav(
    path: Path,
    audio: Tensor,
    sample_rate: int,
    *,
    backend: str | None = None,
    format: str | None = None,
    normalize: bool = True,
    peak: float = 1.0,
    subtype: str | None = None,
) -> None:
    """Save a wav file to disk.

    This entry point is wav-only. For non-wav formats, use
    ``torchrir.io.audio.save_audio``.

    Raises:
        ValueError: If the explicit or inferred format is not wav.
    """

    fmt = _normalize_format(path, format)
    if fmt not in ("wav", "wave"):
        raise ValueError(
            f"save expects a wav file, got format '{fmt}'. "
            "Use torchrir.io.audio.save_audio for non-wav formats."
        )
    writer = _resolve_backend(backend).save
    writer(path, audio, sample_rate, normalize, peak, subtype)

set_audio_backend

set_audio_backend(name)

Set the default audio backend.

Source code in src/torchrir/io/__init__.py
80
81
82
83
84
85
86
87
88
def set_audio_backend(name: str) -> None:
    """Make ``name`` the backend used when callers pass ``backend=None``.

    Raises:
        ValueError: If ``name`` is not a registered backend.
    """
    global _DEFAULT_AUDIO_BACKEND

    if name in _AUDIO_BACKENDS:
        _DEFAULT_AUDIO_BACKEND = name
        return
    raise ValueError(
        f"Unknown audio backend '{name}'. Available: {sorted(_AUDIO_BACKENDS)}"
    )

torchrir.logging

torchrir.logging

Logging helpers for torchrir.

__all__ module-attribute

# Public names exported by torchrir.logging.
__all__ = ['LoggingConfig', 'get_logger', 'setup_logging']

LoggingConfig dataclass

Configuration for torchrir logging.

Examples:

config = LoggingConfig(level="INFO")
logger = setup_logging(config)
Source code in src/torchrir/logging.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@dataclass(frozen=True)
class LoggingConfig:
    """Configuration for torchrir logging.

    Examples:
        ```python
        config = LoggingConfig(level="INFO")
        logger = setup_logging(config)
        ```
    """

    level: str | int = "INFO"
    format: str = "%(levelname)s:%(name)s:%(message)s"
    datefmt: Optional[str] = None
    propagate: bool = False

    def resolve_level(self) -> int:
        """Resolve level to a logging integer constant."""
        if isinstance(self.level, int):
            return self.level
        if not isinstance(self.level, str):
            raise TypeError("level must be str or int")
        name = self.level.upper()
        if name not in logging._nameToLevel:
            raise ValueError(f"unknown log level: {self.level}")
        return logging._nameToLevel[name]

    def replace(self, **kwargs) -> "LoggingConfig":
        """Return a new config with updated fields."""
        return replace(self, **kwargs)

replace

replace(**kwargs)

Return a new config with updated fields.

Source code in src/torchrir/logging.py
37
38
39
def replace(self, **kwargs) -> "LoggingConfig":
    """Return a copy with the given fields overridden."""
    updated = replace(self, **kwargs)
    return updated

resolve_level

resolve_level()

Resolve level to a logging integer constant.

Source code in src/torchrir/logging.py
26
27
28
29
30
31
32
33
34
35
def resolve_level(self) -> int:
    """Translate ``level`` into its integer logging constant.

    Raises:
        TypeError: If ``level`` is neither ``str`` nor ``int``.
        ValueError: If ``level`` is a name unknown to ``logging``.
    """
    requested = self.level
    if isinstance(requested, int):
        return requested
    if not isinstance(requested, str):
        raise TypeError("level must be str or int")
    known = logging._nameToLevel
    key = requested.upper()
    if key in known:
        return known[key]
    raise ValueError(f"unknown log level: {self.level}")

get_logger

get_logger(name=None)

Return a torchrir logger, namespaced under the torchrir root.

Examples:

logger = get_logger("examples.static")
Source code in src/torchrir/logging.py
63
64
65
66
67
68
69
70
71
72
73
74
75
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """Return a torchrir logger, namespaced under the torchrir root.

    ``None`` or an empty name returns the ``"torchrir"`` root logger itself.
    Names that are already ``"torchrir"`` or start with ``"torchrir."``
    are used as-is; any other name is prefixed, so ``"examples.static"``
    becomes ``"torchrir.examples.static"``. The check follows the dotted
    logger hierarchy. A name such as ``"torchrir_extras"`` is therefore
    namespaced, not mistaken for a child of the root.

    Examples:
        ```python
        logger = get_logger("examples.static")
        ```
    """
    root = "torchrir"
    if not name:
        return logging.getLogger(root)
    if name == root or name.startswith(root + "."):
        return logging.getLogger(name)
    return logging.getLogger(f"{root}.{name}")

setup_logging

setup_logging(config, *, name='torchrir')

Configure and return the base torchrir logger.

Examples:

logger = setup_logging(LoggingConfig(level="DEBUG"))
logger.info("ready")
Source code in src/torchrir/logging.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def setup_logging(config: "LoggingConfig", *, name: str = "torchrir") -> logging.Logger:
    """Configure and return the base torchrir logger.

    Sets the logger's level and ``propagate`` flag from ``config``. If the
    logger has no handlers, a ``StreamHandler`` is installed and tagged as
    owned by torchrir. On later calls, every torchrir-owned handler is
    re-configured with the new level and format. Calling this again with a
    different level (e.g. INFO then DEBUG) therefore takes effect instead
    of being filtered by a stale handler level. Handlers added by other
    code are never touched.

    Examples:
        ```python
        logger = setup_logging(LoggingConfig(level="DEBUG"))
        logger.info("ready")
        ```
    """
    logger = logging.getLogger(name)
    level = config.resolve_level()
    logger.setLevel(level)
    logger.propagate = config.propagate

    owned = [h for h in logger.handlers if getattr(h, "_torchrir_handler", False)]
    if not logger.handlers:
        handler = logging.StreamHandler()
        # Tag the handler so a later call can find and update it.
        handler._torchrir_handler = True
        logger.addHandler(handler)
        owned = [handler]

    formatter = logging.Formatter(config.format, datefmt=config.datefmt)
    for handler in owned:
        handler.setLevel(level)
        handler.setFormatter(formatter)
    return logger

torchrir.config

torchrir.config

Configuration objects for torchrir.

__all__ module-attribute

# Public names exported by torchrir.config.
__all__ = ['SimulationConfig', 'default_config']

SimulationConfig dataclass

Configuration values for RIR simulation and convolution.

Examples:

cfg = SimulationConfig(max_order=6, tmax=0.3, device="auto")
cfg.validate()
Source code in src/torchrir/config.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@dataclass(frozen=True)
class SimulationConfig:
    """Configuration values for RIR simulation and convolution.

    Examples:
        ```python
        cfg = SimulationConfig(max_order=6, tmax=0.3, device="auto")
        cfg.validate()
        ```
    """

    fs: Optional[float] = None
    max_order: Optional[int] = None
    tmax: Optional[float] = None
    directivity: Optional[str | tuple[str, str]] = None
    device: Optional[torch.device | str] = None
    seed: Optional[int] = None
    use_lut: bool = True
    mixed_precision: bool = False
    frac_delay_length: int = 81
    sinc_lut_granularity: int = 20
    image_chunk_size: int = 2048
    accumulate_chunk_size: int = 4096
    use_compile: bool = False
    rir_hpf_enable: bool = True
    rir_hpf_fc: float = 10.0
    rir_hpf_kwargs: dict[str, float | int | str] = field(
        default_factory=lambda: {"n": 2, "rp": 5.0, "rs": 60.0, "type": "butter"}
    )

    def validate(self) -> None:
        """Validate configuration values."""
        if self.fs is not None and self.fs <= 0:
            raise ValueError("fs must be positive")
        if self.max_order is not None and self.max_order < 0:
            raise ValueError("max_order must be non-negative")
        if self.tmax is not None and self.tmax <= 0:
            raise ValueError("tmax must be positive")
        if self.seed is not None and self.seed < 0:
            raise ValueError("seed must be non-negative")
        if self.frac_delay_length <= 0 or self.frac_delay_length % 2 == 0:
            raise ValueError("frac_delay_length must be a positive odd integer")
        if self.sinc_lut_granularity <= 0:
            raise ValueError("sinc_lut_granularity must be positive")
        if self.image_chunk_size <= 0:
            raise ValueError("image_chunk_size must be positive")
        if self.accumulate_chunk_size <= 0:
            raise ValueError("accumulate_chunk_size must be positive")
        if self.rir_hpf_fc <= 0:
            raise ValueError("rir_hpf_fc must be positive")
        if "n" in self.rir_hpf_kwargs and int(self.rir_hpf_kwargs["n"]) <= 0:
            raise ValueError("rir_hpf_kwargs['n'] must be positive")

    def replace(self, **kwargs) -> "SimulationConfig":
        """Return a new config with updated fields."""
        new_cfg = replace(self, **kwargs)
        new_cfg.validate()
        return new_cfg

replace

replace(**kwargs)

Return a new config with updated fields.

Source code in src/torchrir/config.py
64
65
66
67
68
def replace(self, **kwargs) -> "SimulationConfig":
    """Return a copy with ``kwargs`` applied, validated before returning.

    Raises:
        ValueError: If the resulting configuration fails ``validate``.
    """
    candidate = replace(self, **kwargs)
    candidate.validate()
    return candidate

validate

validate()

Validate configuration values.

Source code in src/torchrir/config.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def validate(self) -> None:
    """Check configuration values and raise on the first violation.

    Rules are evaluated lazily and in a fixed order. Once one rule fails, later
    rules (including the one that reads ``rir_hpf_kwargs``) are not evaluated.

    Raises:
        ValueError: Naming the first invalid field.
    """
    rules = (
        (lambda: self.fs is not None and self.fs <= 0, "fs must be positive"),
        (
            lambda: self.max_order is not None and self.max_order < 0,
            "max_order must be non-negative",
        ),
        (
            lambda: self.tmax is not None and self.tmax <= 0,
            "tmax must be positive",
        ),
        (
            lambda: self.seed is not None and self.seed < 0,
            "seed must be non-negative",
        ),
        (
            # Odd length keeps the fractional-delay filter centered on a sample.
            lambda: self.frac_delay_length <= 0 or self.frac_delay_length % 2 == 0,
            "frac_delay_length must be a positive odd integer",
        ),
        (
            lambda: self.sinc_lut_granularity <= 0,
            "sinc_lut_granularity must be positive",
        ),
        (
            lambda: self.image_chunk_size <= 0,
            "image_chunk_size must be positive",
        ),
        (
            lambda: self.accumulate_chunk_size <= 0,
            "accumulate_chunk_size must be positive",
        ),
        (lambda: self.rir_hpf_fc <= 0, "rir_hpf_fc must be positive"),
        (
            lambda: "n" in self.rir_hpf_kwargs
            and int(self.rir_hpf_kwargs["n"]) <= 0,
            "rir_hpf_kwargs['n'] must be positive",
        ),
    )
    for violated, message in rules:
        if violated():
            raise ValueError(message)

default_config

default_config()

Return the default simulation configuration.

Examples:

cfg = default_config()
Source code in src/torchrir/config.py
71
72
73
74
75
76
77
78
79
80
81
def default_config() -> SimulationConfig:
    """Build a ``SimulationConfig`` from its field defaults and validate it.

    Examples:
        ```python
        cfg = default_config()
        ```
    """
    config = SimulationConfig()
    config.validate()
    return config

torchrir.util

torchrir.util

General-purpose math, device, and tensor utilities for torchrir.

__all__ module-attribute

__all__ = ['DeviceSpec', 'add_output_args', 'as_tensor', 'attenuation_db_to_time_sabine', 'ensure_dim', 'estimate_beta_from_t60', 'estimate_image_counts_from_tmax', 'estimate_t60_from_beta', 'extend_size', 'infer_device_dtype', 'normalize_orientation', 'orientation_to_unit', 'resolve_device']

DeviceSpec dataclass

Resolve device + dtype defaults consistently.

Examples:

spec = DeviceSpec(device="auto", dtype=torch.float32)
device, dtype = spec.resolve(tensor)
Source code in src/torchrir/util/device.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@dataclass(frozen=True)
class DeviceSpec:
    """Resolve device + dtype defaults consistently.

    Examples:
        ```python
        spec = DeviceSpec(device="auto", dtype=torch.float32)
        device, dtype = spec.resolve(tensor)
        ```
    """

    device: Optional[torch.device | str] = None
    dtype: Optional[torch.dtype] = None
    prefer: Tuple[str, ...] = ("cuda", "mps", "cpu")

    def resolve(self, *values) -> Tuple[torch.device, torch.dtype]:
        """Resolve device/dtype from inputs with overrides."""
        tensor_device: Optional[torch.device] = None
        tensor_dtype: Optional[torch.dtype] = None
        for value in values:
            if torch.is_tensor(value):
                if tensor_device is None:
                    tensor_device = value.device
                if tensor_dtype is None:
                    tensor_dtype = value.dtype

        if isinstance(self.device, str) and self.device.lower() == "auto":
            device = tensor_device or resolve_device("auto", prefer=self.prefer)
        elif self.device is None:
            device = tensor_device or torch.device("cpu")
        else:
            device = resolve_device(self.device, prefer=self.prefer)

        if self.dtype is None:
            dtype = tensor_dtype or torch.float32
        else:
            dtype = self.dtype
        return device, dtype

resolve

resolve(*values)

Resolve device/dtype from inputs with overrides.

Source code in src/torchrir/util/device.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def resolve(self, *values) -> Tuple[torch.device, torch.dtype]:
    """Resolve device/dtype from inputs with overrides.

    The first tensor in ``values`` supplies the fallback device and dtype;
    non-tensor values are ignored.

    Device rules:
        - ``"auto"``: the tensor's device, else ``resolve_device("auto")``.
        - ``None``: the tensor's device, else CPU.
        - anything else: ``resolve_device(self.device)``.

    Dtype rules:
        - an explicit ``self.dtype`` wins;
        - otherwise the tensor's dtype, else ``torch.float32``.
    """
    first_tensor = next((v for v in values if torch.is_tensor(v)), None)
    seen_device = None if first_tensor is None else first_tensor.device
    seen_dtype = None if first_tensor is None else first_tensor.dtype

    requested = self.device
    if isinstance(requested, str) and requested.lower() == "auto":
        if seen_device is not None:
            device = seen_device
        else:
            device = resolve_device("auto", prefer=self.prefer)
    elif requested is None:
        device = seen_device if seen_device is not None else torch.device("cpu")
    else:
        device = resolve_device(requested, prefer=self.prefer)

    if self.dtype is not None:
        dtype = self.dtype
    elif seen_dtype is not None:
        dtype = seen_dtype
    else:
        dtype = torch.float32
    return device, dtype

add_output_args

add_output_args(parser, *, out_dir_default, plot_default=False, include_plot=True, include_show=True, include_gif=False)

Add common output/plot/GIF arguments to a parser.

Source code in src/torchrir/util/cli.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def add_output_args(
    parser: argparse.ArgumentParser,
    *,
    out_dir_default: str | Path,
    plot_default: bool = False,
    include_plot: bool = True,
    include_show: bool = True,
    include_gif: bool = False,
) -> None:
    """Add common output/plot/GIF arguments to a parser."""
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=Path(out_dir_default),
        help="Output directory for WAV/metadata/plots/GIFs.",
    )
    if include_plot:
        parser.add_argument(
            "--plot",
            action="store_true",
            default=plot_default,
            help="Plot room layout and trajectories."
            if not plot_default
            else "Plot outputs (PNG).",
        )
        if plot_default:
            parser.add_argument(
                "--no-plot",
                action="store_false",
                dest="plot",
                help="Disable plotting.",
            )
    if include_show:
        parser.add_argument(
            "--show", action="store_true", help="show plots interactively"
        )
    if include_gif:
        parser.add_argument(
            "--gif", action="store_true", help="Save trajectory animation GIF."
        )
        parser.add_argument(
            "--gif-fps",
            type=int,
            default=-1,
            help="GIF frames per second (<=0 uses auto).",
        )

as_tensor

as_tensor(value, *, device=None, dtype=None)

Convert a value to a tensor while preserving device/dtype when possible.

Source code in src/torchrir/util/tensor.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def as_tensor(
    value: Tensor | Iterable[float] | Iterable[Iterable[float]] | float | int,
    *,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Convert a value to a tensor while preserving device/dtype when possible."""
    if isinstance(device, str):
        device = resolve_device(device)
    if torch.is_tensor(value):
        out = value
        if device is not None:
            out = out.to(device)
        if dtype is not None:
            out = out.to(dtype)
        return out
    return torch.as_tensor(value, device=device, dtype=dtype)

attenuation_db_to_time_sabine

attenuation_db_to_time_sabine(att_db, t60)

Convert attenuation (dB) to time based on T60.

Note

This function corresponds to gpuRIR's att2t_SabineEstimation. TorchRIR uses snake_case naming for consistency.

Examples:

t = attenuation_db_to_time_sabine(att_db=60.0, t60=0.4)
Source code in src/torchrir/util/acoustics.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def attenuation_db_to_time_sabine(att_db: float, t60: float) -> float:
    """Return the time needed for ``att_db`` of decay, given a T60.

    Under Sabine's model the decay is linear in dB, so ``att_db`` takes
    ``att_db / 60`` of the 60 dB decay time.

    Note:
        This function corresponds to gpuRIR's ``att2t_SabineEstimation``. TorchRIR
        uses snake_case naming for consistency.

    Raises:
        ValueError: If ``t60`` (checked first) or ``att_db`` is not positive.

    Examples:
        ```python
        t = attenuation_db_to_time_sabine(att_db=60.0, t60=0.4)
        ```
    """
    if t60 <= 0:
        raise ValueError("t60 must be positive")
    if att_db <= 0:
        raise ValueError("att_db must be positive")
    fraction_of_t60 = att_db / 60.0
    return fraction_of_t60 * t60

ensure_dim

ensure_dim(size)

Validate room size dimensionality (2D or 3D).

Source code in src/torchrir/util/tensor.py
32
33
34
35
36
def ensure_dim(size: Tensor) -> Tensor:
    """Return ``size`` unchanged if it is a 1D tensor with 2 or 3 entries.

    Raises:
        ValueError: For any other rank or length.
    """
    is_vector = size.ndim == 1
    if is_vector and size.numel() in (2, 3):
        return size
    raise ValueError("room size must be a 1D tensor of length 2 or 3")

estimate_beta_from_t60

estimate_beta_from_t60(size, t60, *, device=None, dtype=None)

Estimate reflection coefficients from T60 using Sabine's formula.

Note

This function corresponds to gpuRIR's beta_SabineEstimation. TorchRIR uses snake_case naming for consistency.

Examples:

beta = estimate_beta_from_t60(torch.tensor([6.0, 4.0, 3.0]), t60=0.4)
Source code in src/torchrir/util/acoustics.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def estimate_beta_from_t60(
    size: Tensor,
    t60: float,
    *,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Estimate reflection coefficients from T60 using Sabine's formula.

    Sabine's equation ``T60 = 0.161 * V / (S * alpha)`` is solved for one
    absorption coefficient ``alpha`` shared by every wall. The coefficient is
    clamped to ``[0, 0.999]`` and converted to ``beta = sqrt(1 - alpha)``. The
    0.161 constant assumes metric units (metres, seconds).

    A 2D room is treated as a unit-height room without floor/ceiling: volume
    ``lx * ly`` and wall area ``2 * (lx + ly)``.

    Note:
        This function corresponds to gpuRIR's ``beta_SabineEstimation``. TorchRIR
        uses snake_case naming for consistency.

    Args:
        size: Room dimensions (2 or 3 entries).
        t60: Target reverberation time.
        device: Optional device override for ``size``.
        dtype: Optional dtype override for ``size``.

    Returns:
        Identical reflection coefficients on ``size``'s device: 4 entries for 2D
        and 6 for 3D. Integer room sizes are promoted to the default floating
        dtype, because otherwise the coefficients would be truncated to integers.

    Raises:
        ValueError: If ``t60`` is not positive or ``size`` is not 1D of length 2/3.

    Examples:
        ```python
        beta = estimate_beta_from_t60(torch.tensor([6.0, 4.0, 3.0]), t60=0.4)
        ```
    """
    if t60 <= 0:
        raise ValueError("t60 must be positive")
    size = as_tensor(size, device=device, dtype=dtype)
    size = ensure_dim(size)
    if not size.is_floating_point():
        # e.g. torch.tensor([6, 4, 3]) is int64; beta in (0, 1) must stay float.
        size = size.to(torch.get_default_dtype())

    if size.numel() == 2:
        lx, ly = size.tolist()
        lz = 1.0
        surface = 2.0 * (lx + ly) * lz
        n_walls = 4
    else:
        lx, ly, lz = size.tolist()
        surface = 2.0 * (lx * ly + ly * lz + lx * lz)
        n_walls = 6
    volume = lx * ly * lz

    alpha = 0.161 * volume / (t60 * surface)
    # Keep alpha < 1 so beta stays strictly positive (and the sqrt stays real).
    alpha = max(0.0, min(alpha, 0.999))
    beta = math.sqrt(1.0 - alpha)
    return torch.full((n_walls,), beta, device=size.device, dtype=size.dtype)

estimate_image_counts_from_tmax

estimate_image_counts_from_tmax(tmax, room_size, c=_DEF_SPEED_OF_SOUND)

Estimate image counts per dimension needed to cover tmax.

Note

This function corresponds to gpuRIR's t2n helper, renamed for clarity.

Examples:

nb_img = estimate_image_counts_from_tmax(0.3, torch.tensor([6.0, 4.0, 3.0]))
Source code in src/torchrir/util/acoustics.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def estimate_image_counts_from_tmax(
    tmax: float, room_size: Tensor, c: float = _DEF_SPEED_OF_SOUND
) -> Tensor:
    """Estimate, per room dimension, how many images are needed to cover ``tmax``.

    Each count is ``ceil(tmax * c / L)`` as ``int64``, where ``L`` is the room
    length along that axis.

    Note:
        This function corresponds to gpuRIR's ``t2n`` helper, renamed for clarity.
        NOTE(review): gpuRIR's ``t2n`` computes ``ceil(2 * T / (L / c))``. This
        version has no factor of 2; confirm whether callers expect a per-side
        count.

    Raises:
        ValueError: If ``tmax`` is not positive or ``room_size`` is not 1D of
            length 2/3.

    Examples:
        ```python
        nb_img = estimate_image_counts_from_tmax(0.3, torch.tensor([6.0, 4.0, 3.0]))
        ```
    """
    if tmax <= 0:
        raise ValueError("tmax must be positive")
    dims = ensure_dim(as_tensor(room_size))
    travel_distance = tmax * c
    return torch.ceil(travel_distance / dims).to(torch.int64)

estimate_t60_from_beta

estimate_t60_from_beta(size, beta, *, device=None, dtype=None)

Estimate T60 from reflection coefficients using Sabine's formula.

Examples:

t60 = estimate_t60_from_beta(torch.tensor([6.0, 4.0, 3.0]), beta=torch.full((6,), 0.9))
Source code in src/torchrir/util/acoustics.py
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def estimate_t60_from_beta(
    size: Tensor,
    beta: Tensor,
    *,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> float:
    """Estimate T60 from reflection coefficients using Sabine's formula.

    Each wall contributes ``area * (1 - beta**2)`` absorption, and
    ``T60 = 0.161 * V / total_absorption``. The 0.161 constant assumes metric
    units.

    A 2D room is treated as a unit-height room with 4 walls, in the order
    ``[y-walls, y-walls, x-walls, x-walls]``. A 3D room also has the two
    ``lx * ly`` faces last.

    Args:
        size: Room dimensions (2 or 3 entries).
        beta: Reflection coefficients, 4 for 2D and 6 for 3D (any shape with
            that many elements; flattened).
        device: Optional device override for ``size``.
        dtype: Optional dtype override for ``size``.

    Returns:
        The estimated T60, or ``inf`` when total absorption is non-positive.

    Raises:
        ValueError: If ``size`` is not 1D of length 2/3, or ``beta`` has the
            wrong number of elements.

    Examples:
        ```python
        t60 = estimate_t60_from_beta(torch.tensor([6.0, 4.0, 3.0]), beta=torch.full((6,), 0.9))
        ```
    """
    size = as_tensor(size, device=device, dtype=dtype)
    size = ensure_dim(size)
    if not size.is_floating_point():
        # beta is cast to size.dtype below; an integer room size would otherwise
        # truncate coefficients like 0.9 to 0 and make every wall fully absorbent.
        size = size.to(torch.get_default_dtype())
    beta = as_tensor(beta, device=size.device, dtype=size.dtype)

    if size.numel() == 2:
        if beta.numel() != 4:
            raise ValueError("beta must have 4 elements for 2D t60 estimation")
        lx, ly = size.tolist()
        lz = 1.0
        wall_areas = [ly * lz, ly * lz, lx * lz, lx * lz]
    else:
        if beta.numel() != 6:
            raise ValueError("beta must have 6 elements for t60 estimation")
        lx, ly, lz = size.tolist()
        wall_areas = [ly * lz, ly * lz, lx * lz, lx * lz, lx * ly, lx * ly]
    volume = lx * ly * lz

    surfaces = torch.tensor(wall_areas, device=size.device, dtype=size.dtype)
    alpha = 1.0 - beta.reshape(-1) ** 2
    absorption = torch.sum(surfaces * alpha).item()
    if absorption <= 0.0:
        # Perfectly reflective walls never decay.
        return float("inf")
    return 0.161 * volume / absorption

extend_size

extend_size(size, dim)

Extend 2D room size to 3D by adding a dummy z dimension.

Source code in src/torchrir/util/tensor.py
39
40
41
42
43
44
45
46
def extend_size(size: Tensor, dim: int) -> Tensor:
    """Return ``size`` as a ``dim``-element room size.

    If ``size`` already has ``dim`` entries it is returned as-is. A 2D size asked
    for as 3D gets a trailing ``1`` (unit height) with the same device/dtype.

    Raises:
        ValueError: For any other combination.
    """
    count = size.numel()
    if count == dim:
        return size
    if (count, dim) != (2, 3):
        raise ValueError("unsupported room dimension")
    unit_depth = torch.ones(1, device=size.device, dtype=size.dtype)
    return torch.cat([size, unit_depth])

infer_device_dtype

infer_device_dtype(*values, device=None, dtype=None)

Infer device/dtype from inputs with optional overrides.

Source code in src/torchrir/util/device.py
 98
 99
100
101
102
103
104
def infer_device_dtype(
    *values,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.device, torch.dtype]:
    """Infer ``(device, dtype)`` from ``values``, honoring explicit overrides.

    This is a functional shortcut for ``DeviceSpec(device, dtype).resolve``.
    """
    spec = DeviceSpec(device=device, dtype=dtype)
    return spec.resolve(*values)

normalize_orientation

normalize_orientation(orientation, *, eps=1e-08)

Normalize orientation vectors with numerical stability.

Source code in src/torchrir/util/orientation.py
 9
10
11
12
13
def normalize_orientation(orientation: Tensor, *, eps: float = 1e-8) -> Tensor:
    """Scale vectors along the last axis to unit length.

    Norms are floored at ``eps``, so zero vectors stay zero instead of turning
    into NaN.
    """
    magnitude = torch.linalg.norm(orientation, dim=-1, keepdim=True).clamp(min=eps)
    return orientation / magnitude

orientation_to_unit

orientation_to_unit(orientation, dim)

Convert orientation representation to unit vectors in 2D/3D.

Source code in src/torchrir/util/orientation.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def orientation_to_unit(orientation: Tensor, dim: int) -> Tensor:
    """Convert an orientation representation to unit vectors in 2D or 3D.

    Accepted forms:

    - ``dim == 2``:
        - a scalar angle in radians;
        - a ``(..., 1)`` tensor of angles;
        - a 1D tensor of angles whose length is not 2;
        - ``(..., 2)`` vectors.

      A 1D tensor of length 2 is read as a vector, not two angles.
    - ``dim == 3``: ``(..., 3)`` vectors, or ``(..., 2)`` ``(azimuth, elevation)``
      pairs in radians.

    Returns:
        Unit vectors with ``dim`` components on the last axis.

    Raises:
        ValueError: If the input matches none of the accepted forms (including a
            scalar with ``dim == 3``), or ``dim`` is neither 2 nor 3.
    """
    if dim == 2:
        if orientation.ndim == 0:
            angle = orientation
            vec = torch.stack([torch.cos(angle), torch.sin(angle)])
            return normalize_orientation(vec)
        if orientation.shape[-1] == 1:
            angle = orientation.squeeze(-1)
            vec = torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)
            return normalize_orientation(vec)
        if orientation.ndim == 1 and orientation.numel() != 2:
            angle = orientation
            vec = torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)
            return normalize_orientation(vec)
        if orientation.shape[-1] == 2:
            return normalize_orientation(orientation)
        raise ValueError("2D orientation must be angle or 2D vector")
    if dim == 3:
        # A 0-dim tensor has no last axis; without this guard, shape[-1] would
        # raise IndexError instead of the documented ValueError.
        if orientation.ndim == 0:
            raise ValueError("3D orientation must be vector or (azimuth, elevation)")
        if orientation.shape[-1] == 3:
            return normalize_orientation(orientation)
        if orientation.shape[-1] == 2:
            azimuth = orientation[..., 0]
            elevation = orientation[..., 1]
            x = torch.cos(elevation) * torch.cos(azimuth)
            y = torch.cos(elevation) * torch.sin(azimuth)
            z = torch.sin(elevation)
            vec = torch.stack([x, y, z], dim=-1)
            return normalize_orientation(vec)
        raise ValueError("3D orientation must be vector or (azimuth, elevation)")
    raise ValueError("unsupported dimension for orientation")

resolve_device

resolve_device(device, *, prefer=('cuda', 'mps', 'cpu'))

Resolve a device string (including 'auto') into a torch.device.

Falls back to CPU when the requested backend is unavailable.

Examples:

device = resolve_device("auto")
Source code in src/torchrir/util/device.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def resolve_device(
    device: Optional[torch.device | str],
    *,
    prefer: Tuple[str, ...] = ("cuda", "mps", "cpu"),
) -> torch.device:
    """Resolve a device string (including 'auto') into a torch.device.

    Falls back to CPU when the requested backend is unavailable.

    Examples:
        ```python
        device = resolve_device("auto")
        ```
    """
    if device is None:
        return torch.device("cpu")
    if isinstance(device, torch.device):
        return device

    dev = str(device).lower()
    if dev == "auto":
        for backend in prefer:
            if backend == "cuda" and torch.cuda.is_available():
                return torch.device("cuda")
            if backend == "mps" and torch.backends.mps.is_available():
                return torch.device("mps")
            if backend == "cpu":
                return torch.device("cpu")
        return torch.device("cpu")

    if dev.startswith("cuda"):
        if torch.cuda.is_available():
            return torch.device(device)
        warnings.warn("CUDA not available; falling back to CPU.", RuntimeWarning)
        return torch.device("cpu")
    if dev == "mps":
        if torch.backends.mps.is_available():
            return torch.device("mps")
        warnings.warn("MPS not available; falling back to CPU.", RuntimeWarning)
        return torch.device("cpu")
    if dev == "cpu":
        return torch.device("cpu")

    return torch.device(device)

torchrir.datasets

torchrir.datasets

Dataset helpers for torchrir.

Includes CMU ARCTIC and LibriSpeech dataset wrappers plus collate utilities for DataLoader usage. Experimental dataset stubs live under torchrir.experimental. Use load_dataset_sources to build fixed-length source signals from random utterances. Dynamic CMU ARCTIC scene generation is available via build_dynamic_cmu_arctic_dataset.

Examples:

from torch.utils.data import DataLoader
from torchrir.datasets import CmuArcticDataset, collate_dataset_items
dataset = CmuArcticDataset("datasets/cmu_arctic", speaker="bdl", download=True)
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_dataset_items)
from pathlib import Path
from torchrir.datasets import LibriSpeechDataset
librispeech = LibriSpeechDataset(Path("datasets/librispeech"), subset="train-clean-100")

__all__ module-attribute

__all__ = ['BaseDataset', 'CmuArcticDataset', 'CmuArcticSentence', 'choose_speakers', 'DatasetItem', 'DatasetAttribution', 'CollateBatch', 'default_modification_notes', 'collate_dataset_items', 'cmu_arctic_speakers', 'build_dynamic_cmu_arctic_dataset', 'attribution_for', 'SentenceLike', 'load_dataset_sources', 'load', 'save', 'LibriSpeechDataset', 'LibriSpeechSentence']

BaseDataset

Bases: Dataset[DatasetItem]

Base dataset class compatible with torch.utils.data.Dataset.

Source code in src/torchrir/datasets/base.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class BaseDataset(Dataset[DatasetItem]):
    """Base dataset class compatible with torch.utils.data.Dataset.

    Subclasses implement ``available_sentences``, ``load_audio``,
    ``list_speakers`` and ``attribution_info``. This class provides ``len()``
    and integer indexing on top of a lazily built, per-instance list of the
    available sentences.
    """

    # Class-level default only; ``_get_sentences`` assigns the real list on the
    # instance the first time it is needed.
    _sentences_cache: Optional[list[SentenceLike]] = None

    def list_speakers(self) -> list[str]:
        """Return available speaker IDs."""
        raise NotImplementedError

    def available_sentences(self) -> Sequence[SentenceLike]:
        """Return sentence entries that have audio available."""
        raise NotImplementedError

    def load_audio(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
        """Load audio for an utterance and return (audio, sample_rate)."""
        raise NotImplementedError

    def attribution_info(self) -> DatasetAttribution:
        """Return attribution and license information for this dataset."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of sentences that have audio available."""
        return len(self._get_sentences())

    def __getitem__(self, idx) -> DatasetItem:  # ty: ignore[invalid-method-override]
        """Load the ``idx``-th available sentence as a ``DatasetItem``.

        ``idx`` may be any integer-like object implementing ``__index__``, such
        as a Python ``int`` or a NumPy integer scalar. Negative indices follow
        normal list semantics.

        Raises:
            TypeError: If ``idx`` is not integer-like (e.g. a float or slice).
            IndexError: If ``idx`` is out of range.
        """
        import operator

        try:
            index = operator.index(idx)
        except TypeError:
            raise TypeError(f"Index must be int, got {type(idx)!r}") from None
        sentence = self._get_sentences()[index]
        audio, sample_rate = self.load_audio(sentence.utterance_id)
        return DatasetItem(
            audio=audio,
            sample_rate=sample_rate,
            utterance_id=sentence.utterance_id,
            text=getattr(sentence, "text", None),
            # Only subclasses that set ``self.speaker`` report a speaker.
            speaker=getattr(self, "speaker", None),
        )

    def _get_sentences(self) -> list[SentenceLike]:
        """Return the cached sentence list, building it on first use."""
        if self._sentences_cache is None:
            self._sentences_cache = list(self.available_sentences())
        return self._sentences_cache

attribution_info

attribution_info()

Return attribution and license information for this dataset.

Source code in src/torchrir/datasets/base.py
49
50
51
def attribution_info(self) -> DatasetAttribution:
    """Describe this dataset's attribution and license requirements.

    Abstract hook: subclasses must override it; the base version always raises.
    """
    raise NotImplementedError()

available_sentences

available_sentences()

Return sentence entries that have audio available.

Source code in src/torchrir/datasets/base.py
41
42
43
def available_sentences(self) -> Sequence[SentenceLike]:
    """List the sentence entries whose audio is actually present.

    Abstract hook: subclasses must override it; the base version always raises.
    """
    raise NotImplementedError()

list_speakers

list_speakers()

Return available speaker IDs.

Source code in src/torchrir/datasets/base.py
37
38
39
def list_speakers(self) -> list[str]:
    """List the speaker IDs this dataset can provide.

    Abstract hook: subclasses must override it; the base version always raises.
    """
    raise NotImplementedError()

load_audio

load_audio(utterance_id)

Load audio for an utterance and return (audio, sample_rate).

Source code in src/torchrir/datasets/base.py
45
46
47
def load_audio(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
    """Read the audio for ``utterance_id`` as ``(audio, sample_rate)``.

    Abstract hook: subclasses must override it; the base version always raises.
    """
    raise NotImplementedError()

CmuArcticDataset

Bases: BaseDataset

CMU ARCTIC dataset loader.

Examples:

dataset = CmuArcticDataset(Path("datasets/cmu_arctic"), speaker="bdl", download=True)
audio, fs = dataset.load_audio("arctic_a0001")
Source code in src/torchrir/datasets/cmu_arctic.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class CmuArcticDataset(BaseDataset):
    """CMU ARCTIC dataset loader for a single speaker.

    On-disk layout under ``root``::

        ARCTIC/
            cmu_us_<speaker>_arctic.tar.bz2   # archive, fetched when download=True
            cmu_us_<speaker>_arctic/
                wav/<utterance_id>.wav
                etc/txt.done.data

    Examples:
        ```python
        dataset = CmuArcticDataset(Path("datasets/cmu_arctic"), speaker="bdl", download=True)
        audio, fs = dataset.load_audio("arctic_a0001")
        ```
    """

    def __init__(
        self, root: Path, speaker: str = "bdl", download: bool = False
    ) -> None:
        """Initialize a CMU ARCTIC dataset handle.

        Args:
            root: Root directory where the dataset is stored.
            speaker: Speaker ID (e.g., "bdl"); must be in ``VALID_SPEAKERS``.
            download: Download and extract if missing.

        Raises:
            ValueError: If ``speaker`` is not supported.
            FileNotFoundError: If the extracted dataset directory is still missing
                after the optional download step.
        """
        if speaker not in VALID_SPEAKERS:
            raise ValueError(f"unsupported speaker: {speaker}")
        stem = f"cmu_us_{speaker}_arctic"
        self.root = Path(root)
        self.speaker = speaker
        self._base_dir = self.root / "ARCTIC"
        self._archive_name = f"{stem}.tar.bz2"
        self._dataset_dir = self._base_dir / stem

        if download:
            self._download_and_extract()
        if self._dataset_dir.exists():
            return
        raise FileNotFoundError(
            "dataset not found; run with download=True or place the archive under "
            f"{self._base_dir}"
        )

    @property
    def audio_dir(self) -> Path:
        """Directory holding the ``.wav`` files."""
        return self._dataset_dir.joinpath("wav")

    @property
    def text_path(self) -> Path:
        """Path of the ``etc/txt.done.data`` transcript file."""
        return self._dataset_dir.joinpath("etc", "txt.done.data")

    def _download_and_extract(self) -> None:
        """Fetch and unpack the speaker archive, skipping steps already done.

        If unpacking fails with a read, EOF or OS error, the archive is deleted,
        downloaded again and unpacked once more. A second failure propagates.
        """
        self._base_dir.mkdir(parents=True, exist_ok=True)
        archive_path = self._base_dir / self._archive_name
        url = f"{BASE_URL}/{self._archive_name}"

        if not archive_path.exists():
            logger.info("Downloading %s", url)
            _download(url, archive_path)
        if self._dataset_dir.exists():
            return

        def extract() -> None:
            with tarfile.open(archive_path, "r:bz2") as tar:
                safe_extractall(tar, self._base_dir)

        logger.info("Extracting %s", archive_path)
        try:
            extract()
        except (tarfile.ReadError, EOFError, OSError) as exc:
            logger.warning("Extraction failed (%s); re-downloading.", exc)
            archive_path.unlink(missing_ok=True)
            _download(url, archive_path)
            extract()

    def sentences(self) -> List[CmuArcticSentence]:
        """Parse every non-blank line of ``txt.done.data`` into sentence metadata."""
        with self.text_path.open("r", encoding="utf-8") as handle:
            stripped = (raw.strip() for raw in handle)
            return [
                CmuArcticSentence(utterance_id=utt, text=text)
                for utt, text in map(_parse_text_line, filter(None, stripped))
            ]

    def available_sentences(self) -> List[CmuArcticSentence]:
        """Keep only the sentences whose ``<utterance_id>.wav`` exists."""
        present = {wav.stem for wav in self.audio_dir.glob("*.wav")}
        return [entry for entry in self.sentences() if entry.utterance_id in present]

    def list_speakers(self) -> List[str]:
        """Return every speaker ID known to CMU ARCTIC."""
        return cmu_arctic_speakers()

    def audio_path(self, utterance_id: str) -> Path:
        """Return ``audio_dir / "<utterance_id>.wav"``."""
        return self.audio_dir.joinpath(f"{utterance_id}.wav")

    def load_audio(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
        """Load mono audio for the given utterance ID."""
        return load(self.audio_path(utterance_id))

    def attribution_info(self) -> DatasetAttribution:
        """Return attribution and license information for CMU ARCTIC."""
        return attribution_for("cmu_arctic")

audio_dir property

audio_dir

Return the directory containing audio files.

text_path property

text_path

Return the path to txt.done.data.

__init__

__init__(root, speaker='bdl', download=False)

Initialize a CMU ARCTIC dataset handle.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| root | Path | Root directory where the dataset is stored. | required |
| speaker | str | Speaker ID (e.g., "bdl"). | 'bdl' |
| download | bool | Download and extract if missing. | False |
Source code in src/torchrir/datasets/cmu_arctic.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def __init__(
    self, root: Path, speaker: str = "bdl", download: bool = False
) -> None:
    """Initialize a CMU ARCTIC dataset handle.

    Args:
        root: Root directory where the dataset is stored.
        speaker: Speaker ID (e.g., "bdl"); must be in ``VALID_SPEAKERS``.
        download: Download and extract if missing.

    Raises:
        ValueError: If ``speaker`` is not supported.
        FileNotFoundError: If the extracted dataset directory is still missing
            after the optional download step.
    """
    if speaker not in VALID_SPEAKERS:
        raise ValueError(f"unsupported speaker: {speaker}")
    stem = f"cmu_us_{speaker}_arctic"
    self.root = Path(root)
    self.speaker = speaker
    self._base_dir = self.root / "ARCTIC"
    self._archive_name = f"{stem}.tar.bz2"
    self._dataset_dir = self._base_dir / stem

    if download:
        self._download_and_extract()
    if self._dataset_dir.exists():
        return
    raise FileNotFoundError(
        "dataset not found; run with download=True or place the archive under "
        f"{self._base_dir}"
    )

attribution_info

attribution_info()

Return attribution and license information for CMU ARCTIC.

Source code in src/torchrir/datasets/cmu_arctic.py
156
157
158
def attribution_info(self) -> DatasetAttribution:
    """Look up the attribution record registered under the ``"cmu_arctic"`` key."""
    dataset_key = "cmu_arctic"
    return attribution_for(dataset_key)

audio_path

audio_path(utterance_id)

Return the audio path for an utterance ID.

Source code in src/torchrir/datasets/cmu_arctic.py
147
148
149
def audio_path(self, utterance_id: str) -> Path:
    """Return ``audio_dir / "<utterance_id>.wav"``; the file need not exist."""
    return self.audio_dir.joinpath(f"{utterance_id}.wav")

available_sentences

available_sentences()

Return sentences that have a corresponding wav file.

Source code in src/torchrir/datasets/cmu_arctic.py
138
139
140
141
def available_sentences(self) -> List[CmuArcticSentence]:
    """Keep only the sentences whose ``<utterance_id>.wav`` exists in ``audio_dir``.

    Transcript order is preserved.
    """
    present = {wav.stem for wav in self.audio_dir.glob("*.wav")}
    return [entry for entry in self.sentences() if entry.utterance_id in present]

list_speakers

list_speakers()

Return available speaker IDs.

Source code in src/torchrir/datasets/cmu_arctic.py
143
144
145
def list_speakers(self) -> List[str]:
    """Return every CMU ARCTIC speaker ID, not just this instance's speaker."""
    speakers = cmu_arctic_speakers()
    return speakers

load_audio

load_audio(utterance_id)

Load mono audio for the given utterance ID.

Source code in src/torchrir/datasets/cmu_arctic.py
151
152
153
154
def load_audio(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
    """Load mono audio for the given utterance ID as ``(audio, sample_rate)``."""
    return load(self.audio_path(utterance_id))

sentences

sentences()

Parse all sentence metadata.

Source code in src/torchrir/datasets/cmu_arctic.py
126
127
128
129
130
131
132
133
134
135
136
def sentences(self) -> List[CmuArcticSentence]:
    """Parse every non-blank line of ``txt.done.data`` into sentence metadata.

    Each stripped, non-empty line is split by ``_parse_text_line`` into
    ``(utterance_id, text)``. File order is preserved.
    """
    with self.text_path.open("r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return [
            CmuArcticSentence(utterance_id=utt, text=text)
            for utt, text in map(_parse_text_line, filter(None, stripped))
        ]

CmuArcticSentence dataclass

Sentence metadata from CMU ARCTIC.

Source code in src/torchrir/datasets/cmu_arctic.py
49
50
51
52
53
54
@dataclass
class CmuArcticSentence:
    """One transcript entry from CMU ARCTIC's ``txt.done.data``.

    Attributes:
        utterance_id: Utterance identifier (e.g. ``"arctic_a0001"``); the loader
            looks for ``<utterance_id>.wav`` in the audio directory.
        text: Transcript text for the utterance.
    """

    utterance_id: str  # also the stem of the matching .wav file
    text: str  # transcript as parsed from txt.done.data

CollateBatch dataclass

Collated batch of dataset items.

Fields
  • audio: Padded audio tensor of shape (batch, max_len).
  • lengths: Original lengths for each item.
  • sample_rate: Sample rate shared across the batch.
  • utterance_ids: Utterance IDs per item.
  • texts: Optional text per item.
  • speakers: Optional speaker IDs per item.
  • metadata: Optional per-item metadata (pass-through).
Source code in src/torchrir/datasets/collate.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
@dataclass(frozen=True)
class CollateBatch:
    """Collated batch of dataset items.

    Immutable (``frozen=True``) container built by ``collate_dataset_items``.

    Fields:
        - audio: Padded audio tensor of shape (batch, max_len).
        - lengths: Original lengths for each item.
        - sample_rate: Sample rate shared across the batch.
        - utterance_ids: Utterance IDs per item.
        - texts: Optional text per item.
        - speakers: Optional speaker IDs per item.
        - metadata: Optional per-item metadata (pass-through).
    """

    # (batch, max_len); positions past an item's length hold the pad value.
    audio: Tensor
    # Per-item sample counts (collate_dataset_items stores int64 ``numel()`` values).
    lengths: Tensor
    sample_rate: int
    utterance_ids: list[str]
    texts: list[Optional[str]]
    speakers: list[Optional[str]]
    # Stays None unless collate_dataset_items was called with keep_metadata=True.
    metadata: Optional[list[Any]] = None

DatasetAttribution dataclass

Structured attribution info used for redistribution notices.

Source code in src/torchrir/datasets/attribution.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
@dataclass(frozen=True)
class DatasetAttribution:
    """Immutable attribution/license record used for redistribution notices.

    ``attribution_for`` builds these for the supported dataset keys.
    """

    # Lower-case lookup key, e.g. "cmu_arctic" or "librispeech".
    dataset_key: str
    # Human-readable dataset name.
    dataset: str
    # Upstream URL the data comes from.
    source: str
    license_name: str
    license_url: str
    # Credit line that must accompany redistributed data.
    required_attribution: str
    attribution_required: bool = True
    # Optional subset name (e.g. a LibriSpeech split); None when not applicable.
    subset: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize every field into a plain ``dict`` suitable for JSON output."""
        payload = asdict(self)
        return payload

to_dict

to_dict()

Return a JSON-serializable mapping.

Source code in src/torchrir/datasets/attribution.py
22
23
24
def to_dict(self) -> dict[str, Any]:
    """Serialize every dataclass field into a plain ``dict`` for JSON output."""
    payload = asdict(self)
    return payload

DatasetItem dataclass

Dataset item for DataLoader consumption.

Source code in src/torchrir/datasets/base.py
21
22
23
24
25
26
27
28
29
@dataclass(frozen=True)
class DatasetItem:
    """Dataset item for DataLoader consumption.

    Immutable record; ``collate_dataset_items`` batches a sequence of these.
    """

    # Waveform tensor; presumably mono 1-D (the collate step pads rows by
    # ``audio.numel()``) -- TODO confirm the expected shape.
    audio: torch.Tensor
    sample_rate: int
    utterance_id: str
    # Optional transcript of the utterance.
    text: Optional[str] = None
    # Optional speaker ID.
    speaker: Optional[str] = None

LibriSpeechDataset

Bases: BaseDataset

LibriSpeech dataset loader.

Examples:

dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="train-clean-100", download=True)
audio, fs = dataset.load_audio("103-1240-0000")
Source code in src/torchrir/datasets/librispeech.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class LibriSpeechDataset(BaseDataset):
    """LibriSpeech dataset loader.

    Files are expected under ``<root>/LibriSpeech/<subset>/<speaker>/<chapter>/``,
    with ``<utterance_id>.flac`` audio next to a ``*.trans.txt`` transcript
    file in each chapter directory.

    Examples:
        ```python
        dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="train-clean-100", download=True)
        audio, fs = dataset.load_audio("103-1240-0000")
        ```
    """

    def __init__(
        self,
        root: Path,
        subset: str = "train-clean-100",
        speaker: str | None = None,
        download: bool = False,
    ) -> None:
        """Initialize a LibriSpeech dataset handle.

        Args:
            root: Root directory where the dataset is stored.
            subset: LibriSpeech subset name (e.g., "train-clean-100").
            speaker: Optional speaker ID directory name (e.g., "103").
                If provided, restrict loading to that speaker.
            download: Download and extract if missing.

        Raises:
            ValueError: If ``subset`` is not in ``VALID_SUBSETS``.
            FileNotFoundError: If the subset directory (or the requested
                speaker directory) does not exist after the optional download.
        """
        if subset not in VALID_SUBSETS:
            raise ValueError(f"unsupported subset: {subset}")
        self.root = Path(root)
        self.subset = subset
        self.speaker = speaker
        self._archive_name = f"{subset}.tar.gz"
        self._base_dir = self.root / "LibriSpeech"
        self._subset_dir = self._base_dir / subset
        self._speaker_dir = self._subset_dir / speaker if speaker else None

        if download:
            self._download_and_extract()

        if not self._subset_dir.exists():
            raise FileNotFoundError(
                "dataset not found; run with download=True or place the archive under "
                f"{self.root}"
            )
        if self._speaker_dir is not None and not self._speaker_dir.exists():
            raise FileNotFoundError(f"speaker directory not found: {self._speaker_dir}")

    def list_speakers(self) -> List[str]:
        """Return available speaker IDs.

        When the handle is restricted to one speaker, only that ID is returned.
        Otherwise the sorted names of the subset's sub-directories are returned,
        or an empty list if the subset directory is missing.
        """
        if self.speaker is not None:
            return [self.speaker]
        if not self._subset_dir.exists():
            return []
        return sorted([p.name for p in self._subset_dir.iterdir() if p.is_dir()])

    def available_sentences(self) -> List[LibriSpeechSentence]:
        """Return sentences that have a corresponding audio file.

        Transcript files are visited in sorted path order. ``rglob`` alone
        yields entries in filesystem-dependent order, which could make seeded
        sampling from this list irreproducible across machines.
        """
        sentences: List[LibriSpeechSentence] = []
        search_root = (
            self._speaker_dir if self._speaker_dir is not None else self._subset_dir
        )
        for trans_path in sorted(search_root.rglob("*.trans.txt")):
            # Layout is <speaker>/<chapter>/<...>.trans.txt, so the IDs come
            # from the enclosing directory names.
            chapter_dir = trans_path.parent
            speaker_id = chapter_dir.parent.name
            chapter_id = chapter_dir.name
            with trans_path.open("r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    utt_id, text = _parse_text_line(line)
                    flac_path = chapter_dir / f"{utt_id}.flac"
                    if flac_path.exists():
                        sentences.append(
                            LibriSpeechSentence(
                                utterance_id=utt_id,
                                text=text,
                                speaker_id=speaker_id,
                                chapter_id=chapter_id,
                            )
                        )
        return sentences

    def load_audio(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
        """Load mono audio for the given utterance ID.

        Args:
            utterance_id: LibriSpeech ID of the form
                ``"<speaker>-<chapter>-<index>"`` (e.g. ``"103-1240-0000"``).

        Raises:
            ValueError: If ``utterance_id`` lacks the two ``-`` separators.
        """
        parts = utterance_id.split("-", 2)
        if len(parts) != 3:
            raise ValueError(
                f"invalid LibriSpeech utterance ID {utterance_id!r}; "
                "expected '<speaker>-<chapter>-<index>'"
            )
        speaker_id, chapter_id, _ = parts
        path = self._subset_dir / speaker_id / chapter_id / f"{utterance_id}.flac"
        # Resolves to the module-level ``load_audio`` helper, not this method.
        return load_audio(path)

    def _download_and_extract(self) -> None:
        """Download and extract the subset archive if needed.

        The archive is downloaded only when absent, and extracted only when
        the subset directory is missing. If extraction fails (presumably a
        corrupt or partial archive), the archive is fetched once more and
        extraction is retried. A second failure propagates.
        """
        self.root.mkdir(parents=True, exist_ok=True)
        archive_path = self.root / self._archive_name
        url = f"{BASE_URL}/{self._archive_name}"

        if not archive_path.exists():
            logger.info("Downloading %s", url)
            _download(url, archive_path)
        if self._subset_dir.exists():
            return

        logger.info("Extracting %s", archive_path)
        try:
            self._extract_archive(archive_path)
        except (tarfile.ReadError, EOFError, OSError) as exc:
            logger.warning("Extraction failed (%s); re-downloading.", exc)
            archive_path.unlink(missing_ok=True)
            _download(url, archive_path)
            self._extract_archive(archive_path)

    def _extract_archive(self, archive_path: Path) -> None:
        """Extract the gzip-compressed tar archive into ``self.root``."""
        with tarfile.open(archive_path, "r:gz") as tar:
            safe_extractall(tar, self.root)

    def attribution_info(self) -> DatasetAttribution:
        """Return attribution and license information for LibriSpeech."""
        return attribution_for("librispeech", subset=self.subset)

__init__

__init__(root, subset='train-clean-100', speaker=None, download=False)

Initialize a LibriSpeech dataset handle.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `root` | `Path` | Root directory where the dataset is stored. | *required* |
| `subset` | `str` | LibriSpeech subset name (e.g., "train-clean-100"). | `'train-clean-100'` |
| `speaker` | `str \| None` | Optional speaker ID directory name (e.g., "103"). If provided, restrict loading to that speaker. | `None` |
| `download` | `bool` | Download and extract if missing. | `False` |
Source code in src/torchrir/datasets/librispeech.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def __init__(
    self,
    root: Path,
    subset: str = "train-clean-100",
    speaker: str | None = None,
    download: bool = False,
) -> None:
    """Set up paths for one LibriSpeech subset, optionally downloading it.

    Args:
        root: Directory containing (or receiving) the ``LibriSpeech`` tree.
        subset: Subset name such as "train-clean-100"; must be in
            ``VALID_SUBSETS``.
        speaker: Optional speaker directory name (e.g. "103") that limits
            loading to that speaker. Falsy values mean "all speakers".
        download: When true, fetch and extract the subset if it is missing.

    Raises:
        ValueError: For a subset name not in ``VALID_SUBSETS``.
        FileNotFoundError: When the subset or speaker directory is absent.
    """
    if subset not in VALID_SUBSETS:
        raise ValueError(f"unsupported subset: {subset}")

    self.root = Path(root)
    self.subset = subset
    self.speaker = speaker
    self._archive_name = f"{subset}.tar.gz"
    base_dir = self.root / "LibriSpeech"
    self._base_dir = base_dir
    self._subset_dir = base_dir / subset
    self._speaker_dir = None
    if speaker:
        self._speaker_dir = self._subset_dir / speaker

    if download:
        self._download_and_extract()

    if not self._subset_dir.exists():
        raise FileNotFoundError(
            f"dataset not found; run with download=True or place the archive under {self.root}"
        )
    speaker_dir = self._speaker_dir
    if speaker_dir is not None and not speaker_dir.exists():
        raise FileNotFoundError(f"speaker directory not found: {speaker_dir}")

attribution_info

attribution_info()

Return attribution and license information for LibriSpeech.

Source code in src/torchrir/datasets/librispeech.py
154
155
156
def attribution_info(self) -> DatasetAttribution:
    """Describe LibriSpeech's license and required credit for this handle's subset."""
    return attribution_for(dataset="librispeech", subset=self.subset)

available_sentences

available_sentences()

Return sentences that have a corresponding audio file.

Source code in src/torchrir/datasets/librispeech.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def available_sentences(self) -> List[LibriSpeechSentence]:
    """Return sentences that have a corresponding audio file.

    Searches the speaker directory when the handle is speaker-restricted,
    otherwise the whole subset directory. Transcript files are visited in
    sorted path order. ``rglob`` alone yields entries in filesystem-dependent
    order, which could make seeded sampling from this list irreproducible
    across machines.
    """
    sentences: List[LibriSpeechSentence] = []
    search_root = (
        self._speaker_dir if self._speaker_dir is not None else self._subset_dir
    )
    for trans_path in sorted(search_root.rglob("*.trans.txt")):
        # Layout is <speaker>/<chapter>/<...>.trans.txt, so the IDs come from
        # the enclosing directory names.
        chapter_dir = trans_path.parent
        speaker_id = chapter_dir.parent.name
        chapter_id = chapter_dir.name
        with trans_path.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                utt_id, text = _parse_text_line(line)
                flac_path = chapter_dir / f"{utt_id}.flac"
                if flac_path.exists():
                    sentences.append(
                        LibriSpeechSentence(
                            utterance_id=utt_id,
                            text=text,
                            speaker_id=speaker_id,
                            chapter_id=chapter_id,
                        )
                    )
    return sentences

list_speakers

list_speakers()

Return available speaker IDs.

Source code in src/torchrir/datasets/librispeech.py
90
91
92
93
94
95
96
def list_speakers(self) -> List[str]:
    """Report which speaker IDs this handle can load.

    A speaker-restricted handle reports only its own speaker. Otherwise the
    names of the subset's sub-directories are returned in sorted order, or
    an empty list if the subset directory does not exist.
    """
    if self.speaker is None:
        subset_dir = self._subset_dir
        if subset_dir.exists():
            names = [entry.name for entry in subset_dir.iterdir() if entry.is_dir()]
            return sorted(names)
        return []
    return [self.speaker]

load_audio

load_audio(utterance_id)

Load mono audio for the given utterance ID.

Source code in src/torchrir/datasets/librispeech.py
126
127
128
129
130
def load_audio(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
    """Load mono audio for the given utterance ID.

    Args:
        utterance_id: LibriSpeech ID of the form ``"<speaker>-<chapter>-<index>"``
            (e.g. ``"103-1240-0000"``). The speaker and chapter parts locate
            ``<subset>/<speaker>/<chapter>/<utterance_id>.flac``.

    Raises:
        ValueError: If ``utterance_id`` lacks the two ``-`` separators.
    """
    parts = utterance_id.split("-", 2)
    if len(parts) != 3:
        raise ValueError(
            f"invalid LibriSpeech utterance ID {utterance_id!r}; "
            "expected '<speaker>-<chapter>-<index>'"
        )
    speaker_id, chapter_id, _ = parts
    path = self._subset_dir / speaker_id / chapter_id / f"{utterance_id}.flac"
    # In the class this resolves to the module-level ``load_audio`` helper.
    return load_audio(path)

LibriSpeechSentence dataclass

Sentence metadata from LibriSpeech.

Source code in src/torchrir/datasets/librispeech.py
33
34
35
36
37
38
39
40
@dataclass
class LibriSpeechSentence:
    """Sentence metadata from LibriSpeech.

    Built by ``LibriSpeechDataset.available_sentences()`` for utterances
    whose ``.flac`` file exists.
    """

    # Utterance ID parsed from the chapter's ``*.trans.txt`` file.
    utterance_id: str
    # Transcript text for the utterance.
    text: str
    # Name of the speaker directory containing the chapter.
    speaker_id: str
    # Name of the chapter directory containing the audio.
    chapter_id: str

SentenceLike

Bases: Protocol

Minimal sentence interface for dataset entries.

Source code in src/torchrir/datasets/base.py
14
15
16
17
18
class SentenceLike(Protocol):
    """Minimal sentence interface for dataset entries.

    Structural (``typing.Protocol``) type: any object exposing these two
    attributes matches, e.g. ``CmuArcticSentence`` and ``LibriSpeechSentence``.
    """

    # Identifier used to look up the utterance's audio.
    utterance_id: str
    # Transcript of the utterance.
    text: str

attribution_for

attribution_for(dataset, subset=None)

Return attribution info for a supported dataset key.

Source code in src/torchrir/datasets/attribution.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def attribution_for(dataset: str, subset: Optional[str] = None) -> DatasetAttribution:
    """Look up attribution/license info by (case-insensitive) dataset key.

    Supported keys are ``"cmu_arctic"`` and ``"librispeech"``. ``subset`` is
    recorded only for LibriSpeech and is ignored for CMU ARCTIC.

    Raises:
        ValueError: If the key is not supported.
    """
    builders = {
        "cmu_arctic": lambda: DatasetAttribution(
            dataset_key="cmu_arctic",
            dataset="CMU ARCTIC",
            source="http://www.festvox.org/cmu_arctic/",
            license_name="Permissive (attribution required; see upstream COPYING)",
            license_url="http://www.festvox.org/cmu_arctic/",
            required_attribution=(
                "Carnegie Mellon University, Language Technologies Institute (CMU ARCTIC)"
            ),
        ),
        "librispeech": lambda: DatasetAttribution(
            dataset_key="librispeech",
            dataset="LibriSpeech (SLR12)",
            source="https://www.openslr.org/12",
            license_name="Creative Commons Attribution 4.0 International (CC BY 4.0)",
            license_url="https://creativecommons.org/licenses/by/4.0/",
            required_attribution=(
                "Vassil Panayotov, Guoguo Chen, Daniel Povey, and "
                "Sanjeev Khudanpur (LibriSpeech, 2015)"
            ),
            subset=subset,
        ),
    }
    build = builders.get(dataset.lower())
    if build is None:
        raise ValueError(f"unsupported dataset: {dataset}")
    return build()

build_dynamic_cmu_arctic_dataset

build_dynamic_cmu_arctic_dataset(*, cmu_root, dataset_root=Path('outputs/cmu_arctic_torchrir_dynamic_dataset'), speakers=DEFAULT_SPEAKERS, n_scenes=10, n_sources=3, n_moving_sources=1, duration_sec=20.0, room_size=(8.0, 6.0, 3.0), mic_center=(4.0, 3.0, 1.5), octa_edge_m=1.0, source_margin=(0.5, 0.5, 0.3), trajectory_steps=1024, rir_samples=4096, rt60=0.3, sound_speed=343.0, max_order=6, seed=42, download_cmu=False, overwrite=False, randomize_mic_center=True, move_start_ratio=0.35, move_end_ratio=0.65, moving_speed_min=0.3, moving_speed_max=0.8, save_layout_mp4=True, save_layout_mp4_3d=True, layout_video_fps=None, layout_video_mux_audio=True, save_layout_images=True, save_layout_images_3d=True, annotate_source_indices=True, logger=None)

Build a dynamic CMU ARCTIC dataset with oobss-compatible layout.

Source code in src/torchrir/datasets/dynamic_cmu_arctic.py
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
def build_dynamic_cmu_arctic_dataset(
    *,
    cmu_root: Path,
    dataset_root: Path = Path("outputs/cmu_arctic_torchrir_dynamic_dataset"),
    speakers: Sequence[str] = DEFAULT_SPEAKERS,
    n_scenes: int = 10,
    n_sources: int = 3,
    n_moving_sources: int = 1,
    duration_sec: float = 20.0,
    room_size: Sequence[float] | np.ndarray = (8.0, 6.0, 3.0),
    mic_center: Sequence[float] | np.ndarray = (4.0, 3.0, 1.5),
    octa_edge_m: float = 1.0,
    source_margin: Sequence[float] | np.ndarray = (0.5, 0.5, 0.3),
    trajectory_steps: int = 1024,
    rir_samples: int = 4096,
    rt60: float = 0.3,
    sound_speed: float = 343.0,
    max_order: int = 6,
    seed: int = 42,
    download_cmu: bool = False,
    overwrite: bool = False,
    randomize_mic_center: bool = True,
    move_start_ratio: float = 0.35,
    move_end_ratio: float = 0.65,
    moving_speed_min: float = 0.3,
    moving_speed_max: float = 0.8,
    save_layout_mp4: bool = True,
    save_layout_mp4_3d: bool = True,
    layout_video_fps: float | None = None,
    layout_video_mux_audio: bool = True,
    save_layout_images: bool = True,
    save_layout_images_3d: bool = True,
    annotate_source_indices: bool = True,
    logger: logging.Logger | None = None,
) -> tuple[int, int]:
    """Build a dynamic CMU ARCTIC dataset with oobss-compatible layout.

    For each scene, ``n_sources`` distinct speakers are drawn and placed around
    an octahedral microphone array in a shoebox room. ``n_moving_sources`` of
    them follow an arc between ``move_start_ratio`` and ``move_end_ratio`` of
    the clip. Dynamic RIRs are simulated and convolved with fixed-length dry
    signals. Each scene is written to ``dataset_root/scene_XXXX/`` as:

    - ``source_XX.wav`` stems and ``mixture.wav``;
    - ``metadata.json`` and ``source_info.json``;
    - optionally, layout images and videos.

    Args:
        cmu_root: Root directory of the CMU ARCTIC corpus.
        dataset_root: Output directory; must not exist unless ``overwrite``.
        speakers: Candidate speaker IDs (at least ``n_sources``).
        n_scenes: Number of scenes to generate (> 0).
        n_sources: Sources per scene (> 0, <= number of microphones).
        n_moving_sources: Moving sources per scene (0..n_sources).
        duration_sec: Length of every generated clip in seconds (> 0).
        room_size: Shoebox room dimensions (x, y, z) in metres.
        mic_center: Array center used when ``randomize_mic_center`` is False.
        octa_edge_m: Octahedron edge length; circumradius is edge / sqrt(2).
        source_margin: Per-axis margin passed to the placement helpers.
        trajectory_steps: Number of time steps in source/mic trajectories.
        rir_samples: RIR length in samples (> 0).
        rt60: Target reverberation time passed to ``Room.shoebox``.
        sound_speed: Speed of sound in m/s.
        max_order: Maximum reflection order passed to ``simulate_dynamic_rir``.
        seed: Seed for both the NumPy and ``random`` generators.
        download_cmu: Download CMU ARCTIC speakers if missing.
        overwrite: Delete an existing ``dataset_root`` before writing.
        randomize_mic_center: Sample a new array center per scene.
        move_start_ratio: Fraction of the clip at which motion starts.
        move_end_ratio: Fraction of the clip at which motion ends.
        moving_speed_min: Lower bound passed to the motion sampler.
        moving_speed_max: Upper bound passed to the motion sampler.
        save_layout_mp4: Write layout videos.
        save_layout_mp4_3d: Also write 3-D layout videos.
        layout_video_fps: Optional video frame rate override.
        layout_video_mux_audio: Mux the mixture audio into the videos.
        save_layout_images: Write layout images.
        save_layout_images_3d: Also write 3-D layout images.
        annotate_source_indices: Label sources in images and videos.
        logger: Logger to use; defaults to the module logger.

    Returns:
        ``(sample_rate, n_mics)`` shared by every generated scene.

    Raises:
        ValueError: On invalid counts, durations or motion ratios; too few
            speakers; more sources than microphones; or mismatched speaker
            sample rates.
        FileExistsError: If ``dataset_root`` exists and ``overwrite`` is False.
        RuntimeError: If a speaker has no available utterances.
    """
    log = LOGGER if logger is None else logger
    dataset_root = Path(dataset_root)
    room_size_arr = _as_triplet(room_size, name="room_size")
    mic_center_arr = _as_triplet(mic_center, name="mic_center")
    source_margin_arr = _as_triplet(source_margin, name="source_margin")
    speakers_list = [str(speaker) for speaker in speakers]

    if n_scenes <= 0:
        raise ValueError("n_scenes must be > 0")
    if n_sources <= 0:
        raise ValueError("n_sources must be > 0")
    if n_moving_sources < 0 or n_moving_sources > n_sources:
        raise ValueError(
            "n_moving_sources must satisfy 0 <= n_moving_sources <= n_sources"
        )
    if n_sources > len(speakers_list):
        raise ValueError(
            f"n_sources ({n_sources}) must be <= number of provided speakers ({len(speakers_list)})"
        )
    if trajectory_steps <= 0:
        raise ValueError("trajectory_steps must be > 0")
    if duration_sec <= 0.0:
        raise ValueError("duration_sec must be > 0")
    if rir_samples <= 0:
        raise ValueError("rir_samples must be > 0")
    # The motion profile stored in metadata (pre/move/post ratios) is only
    # meaningful when the ratios are ordered within [0, 1].
    if not 0.0 <= move_start_ratio <= move_end_ratio <= 1.0:
        raise ValueError(
            "move ratios must satisfy 0 <= move_start_ratio <= move_end_ratio <= 1"
        )

    min_source_distance_m = 1.8
    move_start_sec = float(duration_sec) * move_start_ratio
    move_end_sec = float(duration_sec) * move_end_ratio

    # Array geometry is cheap to build; rejecting n_sources > n_mics here avoids
    # touching the output directory for a call that can never succeed.
    radius = float(octa_edge_m) / np.sqrt(2.0)
    base_array = polyhedron_array(
        center=[0.0, 0.0, 0.0],
        kind="octahedron",
        radius=radius,
        dtype=torch.float64,
    )
    base_mic_positions = base_array.cpu().numpy().astype(np.float64)
    n_mics = int(base_mic_positions.shape[0])
    if n_sources > n_mics:
        raise ValueError(
            f"n_sources ({n_sources}) must be <= n_mics ({n_mics}) for AuxIVA-based evaluation"
        )

    # Open every speaker (and check sample rates) before deleting anything, so
    # a failure here cannot destroy a previously generated dataset.
    dataset_cache: dict[str, CmuArcticDataset] = {}
    sample_rate: int | None = None
    for speaker in speakers_list:
        ds = CmuArcticDataset(root=cmu_root, speaker=speaker, download=download_cmu)
        dataset_cache[speaker] = ds
        available = ds.available_sentences()
        if not available:
            raise RuntimeError(f"No utterances found for speaker '{speaker}'")
        _, sr = ds.load_audio(available[0].utterance_id)
        if sample_rate is None:
            sample_rate = int(sr)
        elif int(sr) != sample_rate:
            raise ValueError(
                f"Sample rate mismatch across speakers: {sample_rate} vs {int(sr)}"
            )

    # speakers_list is non-empty (n_sources > 0 and n_sources <= len), so the
    # loop above always assigned sample_rate.
    assert sample_rate is not None
    target_samples = int(float(duration_sec) * sample_rate)
    if target_samples <= 0:
        raise ValueError("duration_sec is too small")

    # All validation passed: only now prepare (and possibly wipe) the output.
    if dataset_root.exists():
        if not overwrite:
            raise FileExistsError(
                f"Dataset root already exists: {dataset_root}. Use overwrite=True."
            )
        shutil.rmtree(dataset_root)
    dataset_root.mkdir(parents=True, exist_ok=True)

    rng = np.random.default_rng(seed)
    rng_py = random.Random(seed)

    room = Room.shoebox(
        size=room_size_arr.tolist(),
        fs=float(sample_rate),
        c=float(sound_speed),
        t60=float(rt60),
        dtype=torch.float64,
    )
    convolver = DynamicConvolver(mode="trajectory")

    for scene_idx in range(n_scenes):
        scene_id = f"scene_{scene_idx:04d}"
        scene_dir = dataset_root / scene_id
        scene_dir.mkdir(parents=True, exist_ok=True)
        if randomize_mic_center:
            scene_mic_center = _sample_random_mic_center(
                rng=rng,
                room_size=room_size_arr,
                source_margin=source_margin_arr,
                min_source_distance_m=min_source_distance_m,
                array_radius_m=radius,
            )
        else:
            scene_mic_center = np.asarray(mic_center_arr, dtype=np.float64)
        mic_positions = base_mic_positions + scene_mic_center[None, :]
        mics = MicrophoneArray.from_positions(
            mic_positions.tolist(), dtype=torch.float64
        )
        # The array is static: repeat its positions across every time step.
        mic_traj_np = np.repeat(mic_positions[None, :, :], trajectory_steps, axis=0)

        chosen_speakers = rng_py.sample(speakers_list, n_sources)
        source_signals: list[torch.Tensor] = []
        source_info: list[dict[str, object]] = []
        for speaker in chosen_speakers:
            dataset = dataset_cache[speaker]
            signal, utterance_ids = _load_fixed_length_signal(
                dataset,
                target_samples=target_samples,
                rng_py=rng_py,
            )
            source_signals.append(signal)
            source_info.append({"speaker": speaker, "utterance_ids": utterance_ids})

        dry = torch.stack(source_signals, dim=0).to(dtype=torch.float64)

        (
            starts,
            _ends,
            start_azimuth,
            end_azimuth,
            source_velocity_mps,
            angular_velocity_rad_s,
            turn_direction,
            moving_indices,
        ) = _build_constrained_source_positions(
            rng=rng,
            room_size=room_size_arr,
            mic_center=scene_mic_center,
            margin=source_margin_arr,
            n_sources=n_sources,
            n_moving_sources=n_moving_sources,
            duration_sec=float(duration_sec),
            move_start_ratio=move_start_ratio,
            move_end_ratio=move_end_ratio,
            moving_speed_min=moving_speed_min,
            moving_speed_max=moving_speed_max,
            min_radius_m=min_source_distance_m,
        )

        src_traj_np = _build_source_trajectory(
            starts=starts,
            mic_center=scene_mic_center,
            start_azimuth=start_azimuth,
            end_azimuth=end_azimuth,
            moving_indices=moving_indices,
            n_steps=trajectory_steps,
            move_start_ratio=move_start_ratio,
            move_end_ratio=move_end_ratio,
        )

        moving_index_set = set(moving_indices)
        for src_idx, item in enumerate(source_info):
            item["source_index"] = int(src_idx)
            item["is_moving"] = bool(src_idx in moving_index_set)
            item["velocity_mps"] = float(source_velocity_mps[src_idx])
            item["motion_type"] = "arc" if src_idx in moving_index_set else "static"
            item["angular_velocity_rad_s"] = float(angular_velocity_rad_s[src_idx])
            item["turn_direction"] = int(turn_direction[src_idx])
            item["move_start_sec"] = float(move_start_sec)
            item["move_end_sec"] = float(move_end_sec)

        src_traj = torch.tensor(src_traj_np, dtype=torch.float64)
        mic_traj = torch.tensor(mic_traj_np, dtype=torch.float64)

        rirs = simulate_dynamic_rir(
            room=room,
            src_traj=src_traj,
            mic_traj=mic_traj,
            max_order=max_order,
            nsample=rir_samples,
        )

        # Convolve each source separately so per-source stems can be saved.
        stems: list[np.ndarray] = []
        for src_idx in range(n_sources):
            stem_mc = convolver.convolve(
                dry[src_idx : src_idx + 1],
                rirs[:, src_idx : src_idx + 1, :, :],
            )
            stems.append(_to_time_channel_audio(stem_mc.cpu().numpy(), n_mics=n_mics))

        stems = _normalize_sources(stems)
        mix = np.sum(np.stack(stems, axis=0), axis=0)

        for src_idx, stem in enumerate(stems):
            sf.write(scene_dir / f"source_{src_idx:02d}.wav", stem, sample_rate)
        mixture_path = scene_dir / "mixture.wav"
        sf.write(mixture_path, mix, sample_rate)

        sources = Source.from_positions(starts.tolist(), dtype=torch.float64)
        layout_annotation_lines = _build_layout_annotation_lines(
            scene_id=scene_id,
            move_start_sec=move_start_sec,
            move_end_sec=move_end_sec,
            source_velocity_mps=source_velocity_mps,
        )

        if save_layout_images:
            save_scene_layout_images(
                out_dir=scene_dir,
                room=room.size,
                sources=sources,
                mics=mics,
                logger=log,
                src_traj=src_traj,
                mic_traj=mic_traj,
                save_2d=True,
                save_3d=save_layout_images_3d,
                annotate_sources=annotate_source_indices,
                annotation_lines=layout_annotation_lines,
            )

        if save_layout_mp4:
            save_scene_videos(
                out_dir=scene_dir,
                room=room.size,
                sources=sources,
                mics=mics,
                src_traj=src_traj,
                mic_traj=mic_traj,
                signal_len=target_samples,
                fs=sample_rate,
                logger=log,
                mp4_fps=layout_video_fps,
                save_3d=save_layout_mp4_3d,
                mixture_path=mixture_path,
                mux_audio=layout_video_mux_audio,
                annotate_sources=annotate_source_indices,
                annotation_lines=layout_annotation_lines,
            )

        save_scene_metadata(
            out_dir=scene_dir,
            metadata_name="metadata.json",
            room=room,
            sources=sources,
            mics=mics,
            rirs=rirs,
            src_traj=src_traj,
            mic_traj=mic_traj,
            signal_len=target_samples,
            source_info=source_info,
            extra={
                "scene_id": scene_id,
                "n_sources": n_sources,
                "n_moving_sources": n_moving_sources,
                "octa_edge_m": float(octa_edge_m),
                "mic_center_xyz_m": scene_mic_center.tolist(),
                "randomize_mic_center": bool(randomize_mic_center),
                "min_source_distance_from_array_center_m": float(min_source_distance_m),
                "azimuth_step_deg": float(360.0 / n_sources),
                "moving_source_indices": [int(idx) for idx in moving_indices],
                "start_azimuth_deg": np.rad2deg(start_azimuth).tolist(),
                "end_azimuth_deg": np.rad2deg(end_azimuth).tolist(),
                "source_velocity_mps": source_velocity_mps.tolist(),
                "motion_type": "arc",
                "angular_velocity_rad_s": angular_velocity_rad_s.tolist(),
                "turn_direction": turn_direction.tolist(),
                "motion_profile": {
                    "pre_static_ratio": float(move_start_ratio),
                    "move_ratio": float(move_end_ratio - move_start_ratio),
                    "post_static_ratio": float(1.0 - move_end_ratio),
                },
                "motion_time_sec": {
                    "total": float(duration_sec),
                    "move_start": float(move_start_sec),
                    "move_end": float(move_end_sec),
                },
            },
            logger=log,
        )

        with (scene_dir / "source_info.json").open("w", encoding="utf-8") as fh:
            json.dump(source_info, fh, indent=2)

        log.info(
            "Built %s | speakers=%s | sample_rate=%d",
            scene_id,
            chosen_speakers,
            sample_rate,
        )

    return sample_rate, n_mics

choose_speakers

choose_speakers(dataset, num_sources, rng)

Select unique speakers for the requested number of sources.

Examples:

rng = random.Random(0)
speakers = choose_speakers(dataset, num_sources=2, rng=rng)
Source code in src/torchrir/datasets/utils.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def choose_speakers(
    dataset: BaseDataset, num_sources: int, rng: random.Random
) -> List[str]:
    """Draw ``num_sources`` distinct speaker IDs from ``dataset`` using ``rng``.

    Raises:
        RuntimeError: If the dataset reports no speakers at all.
        ValueError: If more sources are requested than there are speakers.

    Examples:
        ```python
        rng = random.Random(0)
        speakers = choose_speakers(dataset, num_sources=2, rng=rng)
        ```
    """
    pool = dataset.list_speakers()
    if not pool:
        raise RuntimeError("no speakers available")
    pool_size = len(pool)
    if num_sources <= pool_size:
        return rng.sample(pool, num_sources)
    raise ValueError(f"num_sources must be <= {pool_size} for unique speakers")

cmu_arctic_speakers

cmu_arctic_speakers()

Return supported CMU ARCTIC speaker IDs.

Source code in src/torchrir/datasets/cmu_arctic.py
44
45
46
def cmu_arctic_speakers() -> List[str]:
    """List the CMU ARCTIC speaker IDs accepted by this package, in sorted order."""
    speaker_ids = list(VALID_SPEAKERS)
    speaker_ids.sort()
    return speaker_ids

collate_dataset_items

collate_dataset_items(items, *, pad_value=0.0, keep_metadata=False)

Collate DatasetItem entries into a padded batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `items` | `Iterable[DatasetItem]` | Iterable of DatasetItem. | *required* |
| `pad_value` | `float` | Value used for padding. | `0.0` |
| `keep_metadata` | `bool` | Preserve item-level metadata field if present. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `CollateBatch` | CollateBatch with padded audio and metadata lists. |

Source code in src/torchrir/datasets/collate.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def collate_dataset_items(
    items: Iterable[DatasetItem],
    *,
    pad_value: float = 0.0,
    keep_metadata: bool = False,
) -> CollateBatch:
    """Stack ``DatasetItem`` audio into a single right-padded batch.

    Row ``i`` of the output holds item ``i``'s audio starting at column 0.
    Every position past that item's ``audio.numel()`` is filled with
    ``pad_value``. The row width equals the largest ``audio.numel()`` in the
    batch. The padded tensor takes its dtype and device from the first item's
    audio.

    Args:
        items: DatasetItem entries to collate; the iterable is consumed once.
        pad_value: Fill value for positions past each item's length.
        keep_metadata: When True, collect each item's ``metadata`` attribute
            (``None`` where the attribute is absent); otherwise the batch's
            ``metadata`` is ``None``.

    Returns:
        CollateBatch with the padded audio, per-item lengths (``torch.long``),
        the shared sample rate, and per-item utterance IDs, texts and speakers.

    Raises:
        ValueError: If ``items`` is empty or the sample rates disagree.
    """
    batch = list(items)
    if len(batch) == 0:
        raise ValueError("collate_dataset_items received an empty batch")

    first = batch[0]
    sample_rate = first.sample_rate
    if any(item.sample_rate != sample_rate for item in batch[1:]):
        raise ValueError("sample_rate must be consistent within a batch")

    sizes = [item.audio.numel() for item in batch]
    lengths = torch.tensor(sizes, dtype=torch.long)
    padded = torch.full(
        (len(batch), max(sizes)),
        pad_value,
        dtype=first.audio.dtype,
        device=first.audio.device,
    )
    for row, (item, size) in enumerate(zip(batch, sizes)):
        padded[row, :size] = item.audio

    utterance_ids = [item.utterance_id for item in batch]
    texts = [item.text for item in batch]
    speakers = [item.speaker for item in batch]
    metadata: Optional[list[Any]] = (
        [getattr(item, "metadata", None) for item in batch] if keep_metadata else None
    )

    return CollateBatch(
        audio=padded,
        lengths=lengths,
        sample_rate=sample_rate,
        utterance_ids=utterance_ids,
        texts=texts,
        speakers=speakers,
        metadata=metadata,
    )

default_modification_notes

default_modification_notes(*, dynamic)

Return concise modification notes for generated outputs.

Source code in src/torchrir/datasets/attribution.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def default_modification_notes(*, dynamic: bool) -> list[str]:
    """Build the short modification notes attached to generated outputs.

    Args:
        dynamic: Chooses the RIR note. True selects the trajectory-based
            (dynamic) wording; False selects the fixed-geometry (static)
            wording.

    Returns:
        A freshly built three-item list: the utterance note, the RIR note,
        then the output note.
    """
    if dynamic:
        rir_note = "Dynamic room impulse responses are simulated with ISM over trajectories."
    else:
        rir_note = "Static room impulse responses are simulated with ISM at fixed geometry."
    return [
        "Utterances are concatenated and trimmed to a fixed duration per source.",
        rir_note,
        "Outputs are derived mixtures and per-source convolved references.",
    ]

load

load(path)

Load a wav file and return mono audio and sample rate.

Notes
  • Multichannel input uses channel 0 only (warns).
  • For non-wav formats, use torchrir.io.audio.load_audio.

Examples:

audio, fs = load(Path("datasets/cmu_arctic/.../arctic_a0001.wav"))
Source code in src/torchrir/io/audio.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def load(path: Path) -> Tuple[torch.Tensor, int]:
    """Load a wav file and return mono audio and sample rate.

    Notes:
        - Multichannel input uses channel 0 only (warns).
        - For non-wav formats, use ``torchrir.io.audio.load_audio``.

    Examples:
        ```python
        audio, fs = load(Path("datasets/cmu_arctic/.../arctic_a0001.wav"))
        ```
    """
    suffix = path.suffix.lower()
    if suffix not in {".wav", ".wave"}:
        raise ValueError(
            f"load expects a wav file, got '{path.name}'. "
            "Use torchrir.io.audio.load_audio for non-wav formats."
        )
    return _load_audio(path, caller="load")

load_dataset_sources

load_dataset_sources(*, dataset_factory, num_sources, duration_s, rng)

Load and concatenate utterances for each speaker into fixed-length signals.

Examples:

from pathlib import Path
from torchrir.datasets import CmuArcticDataset
rng = random.Random(0)
root = Path("datasets/cmu_arctic")
signals, fs, info = load_dataset_sources(
    dataset_factory=lambda spk: CmuArcticDataset(root, speaker=spk, download=True),
    num_sources=2,
    duration_s=10.0,
    rng=rng,
)
Source code in src/torchrir/datasets/utils.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def load_dataset_sources(
    *,
    dataset_factory: Callable[[Optional[str]], BaseDataset],
    num_sources: int,
    duration_s: float,
    rng: random.Random,
) -> Tuple[torch.Tensor, int, List[Tuple[str, List[str]]]]:
    """Load and concatenate utterances for each speaker into fixed-length signals.

    Speakers are drawn with ``choose_speakers`` from ``dataset_factory(None)``.
    For each chosen speaker, the function visits utterances in shuffled order
    and reshuffles whenever the list is exhausted. It concatenates them until
    at least ``int(duration_s * fs)`` samples are collected, then trims the
    result to exactly that length. ``fs`` is the sample rate of the first
    utterance loaded, and every later utterance must match it.

    Args:
        dataset_factory: Called with ``None`` to obtain a dataset for speaker
            selection, then once per chosen speaker ID.
        num_sources: Number of distinct speakers (signals) to produce.
        duration_s: Target duration of each signal in seconds; must be >= 0.
        rng: Random generator used for speaker selection and shuffling.

    Returns:
        ``(signals, fs, info)``:
            - ``signals`` stacks one trimmed signal per speaker along dim 0.
            - ``fs`` is the shared sample rate.
            - ``info`` lists ``(speaker, utterance_ids)``, with the IDs in the
              order their audio was concatenated.

    Raises:
        ValueError: If ``duration_s`` is negative, or on a sample-rate
            mismatch. Also propagated from ``choose_speakers`` (e.g. too many
            sources requested).
        RuntimeError: If a speaker has no sentences or only empty utterances,
            or if no audio was loaded at all (e.g. ``num_sources == 0``).
            Also propagated from ``choose_speakers`` when no speakers exist.

    Examples:
        ```python
        from pathlib import Path
        from torchrir.datasets import CmuArcticDataset
        rng = random.Random(0)
        root = Path("datasets/cmu_arctic")
        signals, fs, info = load_dataset_sources(
            dataset_factory=lambda spk: CmuArcticDataset(root, speaker=spk, download=True),
            num_sources=2,
            duration_s=10.0,
            rng=rng,
        )
        ```
    """
    # A negative target would make ``[:target_samples]`` trim from the end and
    # produce signals of differing lengths; reject it before any loading.
    if duration_s < 0:
        raise ValueError("duration_s must be non-negative")

    dataset0 = dataset_factory(None)
    speakers = choose_speakers(dataset0, num_sources, rng)
    signals: List[torch.Tensor] = []
    info: List[Tuple[str, List[str]]] = []
    fs: int | None = None
    target_samples: int | None = None

    for speaker in speakers:
        dataset = dataset_factory(speaker)
        sentences: Sequence[SentenceLike] = dataset.available_sentences()
        if not sentences:
            raise RuntimeError(f"no sentences found for speaker {speaker}")

        utterance_ids: List[str] = []
        segments: List[torch.Tensor] = []
        total = 0
        sentences = list(sentences)
        rng.shuffle(sentences)
        idx = 0
        # Samples gathered since the last (re)shuffle. If a full pass over all
        # sentences adds nothing, every utterance is empty and the target can
        # never be reached, so the loop would otherwise spin forever.
        pass_samples = 0

        while target_samples is None or total < target_samples:
            if idx >= len(sentences):
                if pass_samples == 0:
                    raise RuntimeError(
                        f"all utterances for speaker {speaker} are empty"
                    )
                rng.shuffle(sentences)
                idx = 0
                pass_samples = 0
            sentence = sentences[idx]
            idx += 1
            audio, sample_rate = dataset.load_audio(sentence.utterance_id)
            if fs is None:
                # The first utterance loaded fixes the shared rate and length.
                fs = sample_rate
                target_samples = int(duration_s * fs)
            elif sample_rate != fs:
                raise ValueError(
                    f"sample rate mismatch: expected {fs}, got {sample_rate} for {speaker}"
                )
            segments.append(audio)
            utterance_ids.append(sentence.utterance_id)
            n_samples = audio.numel()
            total += n_samples
            pass_samples += n_samples

        signal = torch.cat(segments, dim=0)[:target_samples]
        signals.append(signal)
        info.append((speaker, utterance_ids))

    # Check before stacking: torch.stack([]) would otherwise raise first with
    # an unrelated message whenever no speaker was processed.
    if fs is None:
        raise RuntimeError("no audio loaded from dataset sources")
    stacked = torch.stack(signals, dim=0)
    return stacked, int(fs), info

save

save(path, audio, sample_rate, *, normalize=True, peak=1.0, subtype=None)

Save a mono or multi-channel wav to disk.

By default this normalizes to the specified peak. Values outside [-1, 1] are preserved when normalization is disabled. To preserve explicit file metadata, use load_audio_data / save_audio_data or pass subtype directly. For non-wav formats, use torchrir.io.audio.save_audio.

Examples:

save(Path("outputs/example.wav"), audio, sample_rate)
Source code in src/torchrir/io/audio.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def save(
    path: Path,
    audio: torch.Tensor,
    sample_rate: int,
    *,
    normalize: bool = True,
    peak: float = 1.0,
    subtype: str | None = None,
) -> None:
    """Save a mono or multi-channel wav to disk.

    By default this normalizes to the specified peak.
    Values outside [-1, 1] are preserved when normalization is disabled.
    To preserve explicit file metadata, use ``load_audio_data`` /
    ``save_audio_data`` or pass ``subtype`` directly.
    For non-wav formats, use ``torchrir.io.audio.save_audio``.

    Examples:
        ```python
        save(Path("outputs/example.wav"), audio, sample_rate)
        ```
    """
    suffix = path.suffix.lower()
    if suffix not in {".wav", ".wave"}:
        raise ValueError(
            f"save expects a wav file, got '{path.name}'. "
            "Use torchrir.io.audio.save_audio for non-wav formats."
        )
    _save_audio(
        path,
        audio,
        sample_rate,
        normalize=normalize,
        peak=peak,
        subtype=subtype,
    )

torchrir.experimental

torchrir.experimental

Experimental and work-in-progress APIs.

These APIs may change without notice. Prefer the stable interfaces in torchrir and documented submodules where possible.

__all__ module-attribute

__all__ = ['FDTDSimulator', 'RayTracingSimulator', 'TemplateDataset', 'TemplateSentence']

FDTDSimulator dataclass

Work in progress placeholder for FDTD simulation.

Goal

Provide a wave-based solver (finite-difference time-domain) with configurable grid resolution, boundary conditions, and stability constraints. The solver should target CPU/GPU execution and return RIRResult with the same metadata contract as ISM.

Source code in src/torchrir/experimental/simulators.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@dataclass(frozen=True)
class FDTDSimulator:
    """Work in progress placeholder for FDTD simulation.

    Goal:
        Provide a wave-based solver (finite-difference time-domain) with
        configurable grid resolution, boundary conditions, and stability
        constraints. The solver should target CPU/GPU execution and return
        RIRResult with the same metadata contract as ISM.

    Constructing an instance emits a ``RuntimeWarning``; ``simulate`` always
    raises ``NotImplementedError``.
    """

    def __post_init__(self) -> None:
        # __post_init__ is invoked from the dataclass-generated __init__, so
        # stacklevel=3 (this hook -> __init__ -> caller) attributes the warning
        # to the code that constructed the simulator, not to generated code.
        warnings.warn(
            "FDTDSimulator is experimental and not implemented.",
            RuntimeWarning,
            stacklevel=3,
        )

    def simulate(
        self, scene: SceneLike, config: SimulationConfig | None = None
    ) -> RIRResult:
        """Placeholder entry point; always raises.

        Args:
            scene: Accepted for interface compatibility; currently unused.
            config: Accepted for interface compatibility; currently unused.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("FDTDSimulator is not implemented yet")

RayTracingSimulator dataclass

Work in progress placeholder for ray tracing simulation.

Goal

Provide a geometric acoustics backend that traces specular/diffuse reflection paths, supports frequency-dependent absorption/scattering, and returns a RIRResult compatible with the ISM path. The intent is to reuse Scene/SimulationConfig for inputs and keep output shape parity.

Source code in src/torchrir/experimental/simulators.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@dataclass(frozen=True)
class RayTracingSimulator:
    """Work in progress placeholder for ray tracing simulation.

    Goal:
        Provide a geometric acoustics backend that traces specular/diffuse
        reflection paths, supports frequency-dependent absorption/scattering,
        and returns a RIRResult compatible with the ISM path. The intent is to
        reuse Scene/SimulationConfig for inputs and keep output shape parity.

    Constructing an instance emits a ``RuntimeWarning``; ``simulate`` always
    raises ``NotImplementedError``.
    """

    def __post_init__(self) -> None:
        # __post_init__ is invoked from the dataclass-generated __init__, so
        # stacklevel=3 (this hook -> __init__ -> caller) attributes the warning
        # to the code that constructed the simulator, not to generated code.
        warnings.warn(
            "RayTracingSimulator is experimental and not implemented.",
            RuntimeWarning,
            stacklevel=3,
        )

    def simulate(
        self, scene: SceneLike, config: SimulationConfig | None = None
    ) -> RIRResult:
        """Placeholder entry point; always raises.

        Args:
            scene: Accepted for interface compatibility; currently unused.
            config: Accepted for interface compatibility; currently unused.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("RayTracingSimulator is not implemented yet")

TemplateDataset

Bases: BaseDataset

Template dataset stub for future integrations.

This class is a placeholder to document the expected dataset API surface. It will be replaced with concrete dataset loaders in future releases.

Source code in src/torchrir/experimental/datasets.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class TemplateDataset(BaseDataset):
    """Placeholder dataset that sketches the expected dataset interface.

    Creating an instance emits a ``RuntimeWarning`` that points at the
    caller. ``len()`` and indexing both raise ``NotImplementedError``.
    Concrete dataset loaders are intended to replace this stub in future
    releases.
    """

    def __init__(self) -> None:
        notice = "TemplateDataset is experimental and not implemented."
        # __init__ is called directly by the user, so stacklevel=2 lands on
        # the construction site.
        warnings.warn(notice, category=RuntimeWarning, stacklevel=2)

    def __len__(self) -> int:
        """Unsupported: always raises ``NotImplementedError``."""
        raise NotImplementedError("TemplateDataset is not implemented yet")

    def __getitem__(self, idx) -> DatasetItem:
        """Unsupported: always raises ``NotImplementedError``."""
        raise NotImplementedError("TemplateDataset is not implemented yet")

attribution_info

attribution_info()

Return attribution and license information for this dataset.

Source code in src/torchrir/datasets/base.py
49
50
51
def attribution_info(self) -> DatasetAttribution:
    """Describe the dataset's attribution and licensing terms.

    This base version is abstract in practice: it always raises
    ``NotImplementedError``, so concrete datasets must override it.
    """
    raise NotImplementedError()

available_sentences

available_sentences()

Return sentence entries that have audio available.

Source code in src/torchrir/datasets/base.py
41
42
43
def available_sentences(self) -> Sequence[SentenceLike]:
    """Report the sentence entries whose audio is available.

    This base version always raises ``NotImplementedError``; concrete
    datasets must override it.
    """
    raise NotImplementedError()

list_speakers

list_speakers()

Return available speaker IDs.

Source code in src/torchrir/datasets/base.py
37
38
39
def list_speakers(self) -> list[str]:
    """Report the speaker IDs this dataset provides.

    This base version always raises ``NotImplementedError``; concrete
    datasets must override it.
    """
    raise NotImplementedError()

load_audio

load_audio(utterance_id)

Load audio for an utterance and return (audio, sample_rate).

Source code in src/torchrir/datasets/base.py
45
46
47
def load_audio(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
    """Fetch the audio for ``utterance_id`` as an ``(audio, sample_rate)`` pair.

    This base version always raises ``NotImplementedError``; concrete
    datasets must override it.
    """
    raise NotImplementedError()

TemplateSentence dataclass

Template for dataset sentences (work in progress).

Source code in src/torchrir/experimental/datasets.py
12
13
14
15
16
17
18
19
@dataclass(frozen=True)
class TemplateSentence:
    """Template for dataset sentences (work in progress).

    Immutable record (``frozen=True``) with four positional/keyword fields.
    Field meanings below are inferred from their names and should be
    confirmed once a concrete dataset uses this template.
    """

    # Utterance identifier; presumably the key passed to a dataset's
    # ``load_audio`` (TODO confirm).
    utterance_id: str
    # Speaker ID the utterance belongs to (inferred from the field name).
    speaker: str
    # Transcript text of the utterance (inferred from the field name).
    text: str
    # Filesystem path to the utterance audio; presumably a wav file given the
    # name (TODO confirm).
    wav_path: Path