-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathmulti_mic.py
More file actions
1589 lines (1319 loc) · 53.4 KB
/
multi_mic.py
File metadata and controls
1589 lines (1319 loc) · 53.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Multi-microphone components.
This library contains functions for multi-microphone signal processing.
Example
-------
>>> import torch
>>>
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT, ISTFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import GccPhat, SrpPhat, Music
>>> from speechbrain.processing.multi_mic import DelaySum, Mvdr, Gev
>>>
>>> xs_speech = read_audio(
... "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac"
... )
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise_diff = read_audio("tests/samples/multi-mic/noise_diffuse.flac")
>>> xs_noise_diff = xs_noise_diff.unsqueeze(0)
>>> xs_noise_loc = read_audio(
... "tests/samples/multi-mic/noise_0.70225_-0.70225_0.11704.flac"
... )
>>> xs_noise_loc = xs_noise_loc.unsqueeze(0)
>>> fs = 16000 # sampling rate
>>> ss = xs_speech
>>> nn_diff = 0.05 * xs_noise_diff
>>> nn_loc = 0.05 * xs_noise_loc
>>> xs_diffused_noise = ss + nn_diff
>>> xs_localized_noise = ss + nn_loc
>>> # Delay-and-Sum Beamforming with GCC-PHAT localization
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gccphat = GccPhat()
>>> delaysum = DelaySum()
>>> istft = ISTFT(sample_rate=fs)
>>> Xs = stft(xs_diffused_noise)
>>> Ns = stft(nn_diff)
>>> XXs = cov(Xs)
>>> NNs = cov(Ns)
>>> tdoas = gccphat(XXs)
>>> Ys_ds = delaysum(Xs, tdoas)
>>> ys_ds = istft(Ys_ds)
>>> # Mvdr Beamforming with SRP-PHAT localization
>>> mvdr = Mvdr()
>>> mics = torch.zeros((4, 3), dtype=torch.float)
>>> mics[0, :] = torch.FloatTensor([-0.05, -0.05, +0.00])
>>> mics[1, :] = torch.FloatTensor([-0.05, +0.05, +0.00])
>>> mics[2, :] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> mics[3, :] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> srpphat = SrpPhat(mics=mics)
>>> doas = srpphat(XXs)
>>> Ys_mvdr = mvdr(Xs, NNs, doas, doa_mode=True, mics=mics, fs=fs)
>>> ys_mvdr = istft(Ys_mvdr)
>>> # Mvdr Beamforming with MUSIC localization
>>> music = Music(mics=mics)
>>> doas = music(XXs)
>>> Ys_mvdr2 = mvdr(Xs, NNs, doas, doa_mode=True, mics=mics, fs=fs)
>>> ys_mvdr2 = istft(Ys_mvdr2)
>>> # GeV Beamforming
>>> gev = Gev()
>>> Xs = stft(xs_localized_noise)
>>> Ss = stft(ss)
>>> Ns = stft(nn_loc)
>>> SSs = cov(Ss)
>>> NNs = cov(Ns)
>>> Ys_gev = gev(Xs, SSs, NNs)
>>> ys_gev = istft(Ys_gev)
Authors:
* William Aris
* Francois Grondin
"""
import torch
import speechbrain.processing.decomposition as eig
class Covariance(torch.nn.Module):
    """Computes the covariance matrices of the signals.

    Arguments
    ---------
    average : bool
        Informs the module if it should return an average
        (computed on the time dimension) of the covariance
        matrices. Any truthy value enables averaging.
        The default value is True.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>>
    >>> xs_speech = read_audio(
    ...    "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac"
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
    >>> xs_noise = read_audio("tests/samples/multi-mic/noise_diffuse.flac")
    >>> xs_noise = xs_noise.unsqueeze(0)
    >>> xs = xs_speech + 0.05 * xs_noise
    >>> fs = 16000
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>>
    >>> Xs = stft(xs)
    >>> XXs = cov(Xs)
    >>> XXs.shape
    torch.Size([1, 1001, 201, 2, 10])
    """

    def __init__(self, average=True):
        super().__init__()
        self.average = average

    def forward(self, Xs):
        """Computes covariance matrices with the utility function _cov.

        The result has the format
        (batch, time_step, n_fft/2 + 1, 2, n_mics + n_pairs).
        The order on the last dimension corresponds to the triu_indices for a
        square matrix. For instance, with 4 channels the order is:
        (0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 3), (2, 2), (2, 3)
        and (3, 3). Therefore, XXs[..., 0] corresponds to channels (0, 0) and
        XXs[..., 1] corresponds to channels (0, 1).

        Arguments
        ---------
        Xs : torch.Tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)

        Returns
        -------
        XXs : torch.Tensor
        """
        return Covariance._cov(Xs=Xs, average=self.average)

    @staticmethod
    def _cov(Xs, average=True):
        """Computes the covariance matrices (XXs) of the signals.

        The result has the format
        (batch, time_step, n_fft/2 + 1, 2, n_mics + n_pairs), where index 0 /
        1 of dimension 3 holds the real / imaginary part, and entry (i, j)
        of the underlying matrix equals conj(X_i) * X_j.

        Arguments
        ---------
        Xs : torch.Tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        average : bool
            If truthy, every time step is replaced by the average
            (over the time dimension) of the covariance matrices.
            Default value is True.

        Returns
        -------
        XXs : torch.Tensor
        """
        n_mics = Xs.shape[4]

        # Column vectors of the real and imaginary parts:
        # (batch, time_step, n_fft/2 + 1, n_mics, 1)
        Xs_re = Xs[..., 0, :].unsqueeze(4)
        Xs_im = Xs[..., 1, :].unsqueeze(4)

        # Outer products; (i, j) = conj(X_i) * X_j split into re / im parts
        Rxx_re = torch.matmul(Xs_re, Xs_re.transpose(3, 4)) + torch.matmul(
            Xs_im, Xs_im.transpose(3, 4)
        )
        Rxx_im = torch.matmul(Xs_re, Xs_im.transpose(3, 4)) - torch.matmul(
            Xs_im, Xs_re.transpose(3, 4)
        )

        # Keep only the upper-triangular part (the matrices are Hermitian)
        idx = torch.triu_indices(n_mics, n_mics, device=Xs.device)
        XXs_re = Rxx_re[..., idx[0], idx[1]]
        XXs_im = Rxx_im[..., idx[0], idx[1]]
        XXs = torch.stack((XXs_re, XXs_im), 3)

        # Truthiness check so numpy/torch booleans and ints behave as expected
        if average:
            n_time_frames = XXs.shape[1]
            XXs = torch.mean(XXs, 1, keepdim=True)
            XXs = XXs.repeat(1, n_time_frames, 1, 1, 1)

        return XXs
class DelaySum(torch.nn.Module):
    """Performs delay and sum beamforming by using the TDOAs and
    the first channel as a reference.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT, ISTFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import GccPhat, DelaySum
    >>>
    >>> xs_speech = read_audio(
    ...    "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac"
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channel]
    >>> xs_noise = read_audio("tests/samples/multi-mic/noise_diffuse.flac")
    >>> xs_noise = xs_noise.unsqueeze(0) # [batch, time, channels]
    >>> fs = 16000
    >>> xs = xs_speech + 0.05 * xs_noise
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gccphat = GccPhat()
    >>> delaysum = DelaySum()
    >>> istft = ISTFT(sample_rate=fs)
    >>>
    >>> Xs = stft(xs)
    >>> XXs = cov(Xs)
    >>> tdoas = gccphat(XXs)
    >>> Ys = delaysum(Xs, tdoas)
    >>> ys = istft(Ys)
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        Xs,
        localization_tensor,
        doa_mode=False,
        mics=None,
        fs=None,
        c=343.0,
    ):
        """Computes a steering vector from the TDOAs/DOAs, then calls the
        utility function _delaysum to perform beamforming.

        The result has the format (batch, time_step, n_fft/2 + 1, 2, 1).

        Arguments
        ---------
        Xs : torch.Tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        localization_tensor : torch.Tensor
            A tensor containing either time differences of arrival (TDOAs)
            (in samples) for each timestamp or directions of arrival (DOAs)
            (xyz coordinates in meters). If localization_tensor represents
            TDOAs, then its format is (batch, time_steps, n_mics + n_pairs).
            If localization_tensor represents DOAs, then its format is
            (batch, time_steps, 3)
        doa_mode : bool
            The user needs to set this parameter to True if localization_tensor
            represents DOAs instead of TDOAs. Its default value is set to False.
        mics : torch.Tensor
            The cartesian position (xyz coordinates in meters) of each microphone.
            The tensor must have the following format (n_mics, 3). This
            parameter is only mandatory when localization_tensor represents
            DOAs. It is moved to the device of Xs.
        fs : int
            The sample rate in Hertz of the signals. This parameter is only
            mandatory when localization_tensor represents DOAs.
        c : float
            The speed of sound in the medium. The speed is expressed in meters
            per second and the default value of this parameter is 343 m/s. This
            parameter is only used when localization_tensor represents DOAs.

        Returns
        -------
        Ys : torch.Tensor
        """
        # Number of frequency bins (n_fft/2 + 1), forwarded to steering()
        n_fft = Xs.shape[2]
        localization_tensor = localization_tensor.to(Xs.device)
        # Keep every input on the same device as the signal (as Mvdr does)
        if mics is not None:
            mics = mics.to(Xs.device)

        # Convert the tdoas/doas to taus
        if doa_mode:
            taus = doas2taus(doas=localization_tensor, mics=mics, fs=fs, c=c)
        else:
            taus = tdoas2taus(tdoas=localization_tensor)

        # Generate the steering vector
        As = steering(taus=taus, n_fft=n_fft)

        # Apply delay and sum
        return DelaySum._delaysum(Xs=Xs, As=As)

    @staticmethod
    def _delaysum(Xs, As):
        """Perform delay and sum beamforming.

        The unmixing weights are conj(As) / n_mics, and the output is the sum
        over microphones of weights * Xs. The result has the format
        (batch, time_step, n_fft/2 + 1, 2, 1).

        Arguments
        ---------
        Xs : torch.Tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        As : torch.Tensor
            The steering vector to point in the direction of
            the target source. The tensor must have the format
            (batch, time_step, n_fft/2 + 1, 2, n_mics)

        Returns
        -------
        Ys : torch.Tensor
        """
        n_mics = Xs.shape[4]

        # Unmixing coefficients: complex conjugate of the steering vector,
        # scaled so that the sum is an average over microphones
        Ws_re = As[..., 0, :] / n_mics
        Ws_im = -1 * As[..., 1, :] / n_mics

        Xs_re = Xs[..., 0, :]
        Xs_im = Xs[..., 1, :]

        # Complex multiply-accumulate over the microphone dimension
        Ys_re = torch.sum((Ws_re * Xs_re - Ws_im * Xs_im), dim=3, keepdim=True)
        Ys_im = torch.sum((Ws_re * Xs_im + Ws_im * Xs_re), dim=3, keepdim=True)

        return torch.stack((Ys_re, Ys_im), 3)
class Mvdr(torch.nn.Module):
    """Perform minimum variance distortionless response (MVDR) beamforming
    by using an input signal in the frequency domain, its covariance matrices
    and tdoas (to compute a steering vector).

    Arguments
    ---------
    eps : float
        A small value added to the MVDR gain denominator to avoid
        divisions by zero. The default value is 1e-20.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT, ISTFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import GccPhat, Mvdr
    >>>
    >>> xs_speech = read_audio(
    ...    "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac"
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channel]
    >>> xs_noise = read_audio("tests/samples/multi-mic/noise_diffuse.flac")
    >>> xs_noise = xs_noise.unsqueeze(0) # [batch, time, channels]
    >>> fs = 16000
    >>> xs = xs_speech + 0.05 * xs_noise
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gccphat = GccPhat()
    >>> mvdr = Mvdr()
    >>> istft = ISTFT(sample_rate=fs)
    >>>
    >>> Xs = stft(xs)
    >>> Ns = stft(xs_noise)
    >>> XXs = cov(Xs)
    >>> NNs = cov(Ns)
    >>> tdoas = gccphat(XXs)
    >>> Ys = mvdr(Xs, NNs, tdoas)
    >>> ys = istft(Ys)
    """

    def __init__(self, eps=1e-20):
        super().__init__()
        self.eps = eps

    def forward(
        self,
        Xs,
        NNs,
        localization_tensor,
        doa_mode=False,
        mics=None,
        fs=None,
        c=343.0,
    ):
        """Computes a steering vector, then uses the utility function _mvdr
        to perform beamforming.

        The result has the format (batch, time_step, n_fft/2 + 1, 2, 1).

        Arguments
        ---------
        Xs : torch.Tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        NNs : torch.Tensor
            The covariance matrices of the noise signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs)
        localization_tensor : torch.Tensor
            A tensor containing either time differences of arrival (TDOAs)
            (in samples) for each timestamp or directions of arrival (DOAs)
            (xyz coordinates in meters). If localization_tensor represents
            TDOAs, then its format is (batch, time_steps, n_mics + n_pairs).
            If localization_tensor represents DOAs, then its format is
            (batch, time_steps, 3)
        doa_mode : bool
            The user needs to set this parameter to True if localization_tensor
            represents DOAs instead of TDOAs. Its default value is set to False.
        mics : torch.Tensor
            The cartesian position (xyz coordinates in meters) of each microphone.
            The tensor must have the following format (n_mics, 3). This
            parameter is only mandatory when localization_tensor represents
            DOAs.
        fs : int
            The sample rate in Hertz of the signals. This parameter is only
            mandatory when localization_tensor represents DOAs.
        c : float
            The speed of sound in the medium. The speed is expressed in meters
            per second and the default value of this parameter is 343 m/s. This
            parameter is only used when localization_tensor represents DOAs.

        Returns
        -------
        Ys : torch.Tensor
        """
        # Number of frequency bins (n_fft/2 + 1), forwarded to steering()
        n_fft = Xs.shape[2]
        localization_tensor = localization_tensor.to(Xs.device)
        NNs = NNs.to(Xs.device)
        if mics is not None:
            mics = mics.to(Xs.device)

        # Convert the tdoas/doas to taus
        if doa_mode:
            taus = doas2taus(doas=localization_tensor, mics=mics, fs=fs, c=c)
        else:
            taus = tdoas2taus(tdoas=localization_tensor)

        # Generate the steering vector
        As = steering(taus=taus, n_fft=n_fft)

        # Perform mvdr; self.eps was previously stored but never forwarded
        return Mvdr._mvdr(Xs=Xs, NNs=NNs, As=As, eps=self.eps)

    @staticmethod
    def _mvdr(Xs, NNs, As, eps=1e-20):
        """Perform minimum variance distortionless response beamforming.

        Arguments
        ---------
        Xs : torch.Tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics).
        NNs : torch.Tensor
            The covariance matrices of the noise signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        As : torch.Tensor
            The steering vector to point in the direction of
            the target source. The tensor must have the format
            (batch, time_step, n_fft/2 + 1, 2, n_mics).
        eps : float
            A small value added to the gain denominator to avoid
            division by zero.

        Returns
        -------
        Ys : torch.Tensor
            Format (batch, time_step, n_fft/2 + 1, 2, 1).
        """
        # Invert only the distinct covariance frames (they repeat over time
        # when Covariance averaged them), then scatter back with the inverse idx
        NNs_val, NNs_idx = torch.unique(NNs, return_inverse=True, dim=1)
        # NOTE(review): assumes eig.inv returns full matrices with re/im on
        # the last dimension — confirm against processing.decomposition.inv
        NNs_inv = eig.inv(NNs_val)
        NNs_inv_re = NNs_inv[..., 0][:, NNs_idx]
        NNs_inv_im = NNs_inv[..., 1][:, NNs_idx]

        # Steering vector as a column (AsC) and its conjugate transpose (AsT)
        AsC_re = As[..., 0, :].unsqueeze(4)
        AsC_im = As[..., 1, :].unsqueeze(4)
        AsT_re = AsC_re.transpose(3, 4)
        AsT_im = -1.0 * AsC_im.transpose(3, 4)

        # Complex product NNs^-1 @ As
        NNs_inv_AsC_re = torch.matmul(NNs_inv_re, AsC_re) - torch.matmul(
            NNs_inv_im, AsC_im
        )
        NNs_inv_AsC_im = torch.matmul(NNs_inv_re, AsC_im) + torch.matmul(
            NNs_inv_im, AsC_re
        )

        # Gain 1 / Re(As^H NNs^-1 As); eps guards against a zero denominator
        alpha = 1.0 / (
            torch.matmul(AsT_re, NNs_inv_AsC_re)
            - torch.matmul(AsT_im, NNs_inv_AsC_im)
            + eps
        )

        # Unmixing coefficients (conjugated for the multiply below)
        Ws_re = torch.matmul(NNs_inv_AsC_re, alpha).squeeze(4)
        Ws_im = -torch.matmul(NNs_inv_AsC_im, alpha).squeeze(4)

        # Apply MVDR: complex multiply-accumulate over microphones
        Xs_re = Xs[..., 0, :]
        Xs_im = Xs[..., 1, :]
        Ys_re = torch.sum((Ws_re * Xs_re - Ws_im * Xs_im), dim=3, keepdim=True)
        Ys_im = torch.sum((Ws_re * Xs_im + Ws_im * Xs_re), dim=3, keepdim=True)

        return torch.stack((Ys_re, Ys_im), -2)
class Gev(torch.nn.Module):
    """Generalized EigenValue decomposition (GEV) Beamforming.

    Example
    -------
    >>> from speechbrain.dataio.dataio import read_audio
    >>> import torch
    >>>
    >>> from speechbrain.processing.features import STFT, ISTFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import Gev
    >>>
    >>> xs_speech = read_audio(
    ...    "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac"
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
    >>> xs_noise = read_audio(
    ...    "tests/samples/multi-mic/noise_0.70225_-0.70225_0.11704.flac"
    ... )
    >>> xs_noise = xs_noise.unsqueeze(0)
    >>> fs = 16000
    >>> ss = xs_speech
    >>> nn = 0.05 * xs_noise
    >>> xs = ss + nn
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gev = Gev()
    >>> istft = ISTFT(sample_rate=fs)
    >>>
    >>> Ss = stft(ss)
    >>> Nn = stft(nn)
    >>> Xs = stft(xs)
    >>>
    >>> SSs = cov(Ss)
    >>> NNs = cov(Nn)
    >>>
    >>> Ys = gev(Xs, SSs, NNs)
    >>> ys = istft(Ys)
    """

    def __init__(self):
        super().__init__()

    def forward(self, Xs, SSs, NNs):
        """Uses the utility function _gev to perform generalized eigenvalue
        decomposition beamforming.

        The result has the format (batch, time_step, n_fft/2 + 1, 2, 1).

        Arguments
        ---------
        Xs : torch.Tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics).
        SSs : torch.Tensor
            The covariance matrices of the target signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        NNs : torch.Tensor
            The covariance matrices of the noise signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).

        Returns
        -------
        Ys : torch.Tensor
        """
        return Gev._gev(Xs=Xs, SSs=SSs, NNs=NNs)

    @staticmethod
    def _gev(Xs, SSs, NNs):
        """Perform generalized eigenvalue decomposition beamforming.

        The result has the format (batch, time_step, n_fft/2 + 1, 2, 1).

        Arguments
        ---------
        Xs : torch.Tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics).
        SSs : torch.Tensor
            The covariance matrices of the target signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        NNs : torch.Tensor
            The covariance matrices of the noise signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).

        Returns
        -------
        Ys : torch.Tensor
        """
        # Putting on the right device
        SSs = SSs.to(Xs.device)
        NNs = NNs.to(Xs.device)

        n_mics = Xs.shape[4]
        n_mics_pairs = SSs.shape[4]

        # Decompose only the distinct (SSs, NNs) time frames; the inverse
        # index restores the full time axis afterwards
        SSs_NNs = torch.cat((SSs, NNs), dim=4)
        SSs_NNs_val, SSs_NNs_idx = torch.unique(
            SSs_NNs, return_inverse=True, dim=1
        )
        SSs = SSs_NNs_val[..., range(0, n_mics_pairs)]
        NNs = SSs_NNs_val[..., range(n_mics_pairs, 2 * n_mics_pairs)]
        NNs = eig.pos_def(NNs)
        Vs, _ = eig.gevd(SSs, NNs)

        # NOTE(review): assumes index n_mics - 1 selects the principal
        # eigenvector in gevd's output layout — confirm in decomposition.gevd
        F_re = Vs[..., (n_mics - 1), 0]
        F_im = Vs[..., (n_mics - 1), 1]

        # Normalize to unit norm. Out-of-place on purpose: F_re / F_im are
        # views of Vs, and mutating them in place breaks autograd whenever
        # gevd saved Vs for its backward pass. Broadcasting replaces .repeat.
        F_norm = 1.0 / (
            torch.sum(F_re**2 + F_im**2, dim=3, keepdim=True) ** 0.5
        )
        F_re = F_re * F_norm
        F_im = F_im * F_norm

        # Restore the time steps removed by torch.unique
        Ws_re = F_re[:, SSs_NNs_idx]
        Ws_im = F_im[:, SSs_NNs_idx]

        # Complex multiply-accumulate over microphones
        Xs_re = Xs[..., 0, :]
        Xs_im = Xs[..., 1, :]
        Ys_re = torch.sum((Ws_re * Xs_re - Ws_im * Xs_im), dim=3, keepdim=True)
        Ys_im = torch.sum((Ws_re * Xs_im + Ws_im * Xs_re), dim=3, keepdim=True)

        return torch.stack((Ys_re, Ys_im), 3)
class GccPhat(torch.nn.Module):
    """Generalized Cross-Correlation with Phase Transform localization.

    Arguments
    ---------
    tdoa_max : int
        Specifies a range to search for delays. For example, if
        tdoa_max = 10, the method will restrict its search to delays
        between -10 and 9 samples. This parameter is optional and its
        default value is None. When tdoa_max is None, the method will
        search the full frame, i.e. delays between -n_fft/2 and
        n_fft/2 - 1. Values larger than n_fft/2 are clamped to n_fft/2
        and values smaller than 1 raise a ValueError.
    eps : float
        A small value to avoid divisions by 0 with the phase transformation.
        The default value is 1e-20.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT, ISTFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import GccPhat, DelaySum
    >>>
    >>> xs_speech = read_audio(
    ...    "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac"
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channel]
    >>> xs_noise = read_audio("tests/samples/multi-mic/noise_diffuse.flac")
    >>> xs_noise = xs_noise.unsqueeze(0) # [batch, time, channels]
    >>> fs = 16000
    >>> xs = xs_speech + 0.05 * xs_noise
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gccphat = GccPhat()
    >>> Xs = stft(xs)
    >>> XXs = cov(Xs)
    >>> tdoas = gccphat(XXs)
    """

    def __init__(self, tdoa_max=None, eps=1e-20):
        super().__init__()
        self.tdoa_max = tdoa_max
        self.eps = eps

    def forward(self, XXs):
        """Perform GCC-PHAT localization.

        The method uses the utility function _gcc_phat, extracts the
        delays (in samples) and then performs a quadratic interpolation
        to improve the accuracy. The result has the format
        (batch, time_steps, n_mics + n_pairs).
        The order on the last dimension corresponds to the triu_indices for a
        square matrix. For instance, with 4 channels the order is:
        (0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 3), (2, 2), (2, 3)
        and (3, 3). Therefore, delays[..., 0] corresponds to channels (0, 0)
        and delays[..., 1] corresponds to channels (0, 1).

        Arguments
        ---------
        XXs : torch.Tensor
            The covariance matrices of the input signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).

        Returns
        -------
        tdoas : torch.Tensor
        """
        xxs = GccPhat._gcc_phat(XXs=XXs, eps=self.eps)
        delays = GccPhat._extract_delays(xxs=xxs, tdoa_max=self.tdoa_max)
        return GccPhat._interpolate(xxs=xxs, delays=delays)

    @staticmethod
    def _gcc_phat(XXs, eps=1e-20):
        """Evaluate GCC-PHAT for each timestamp and return the result in the
        time domain, with format (batch, time_steps, n_fft, n_mics + n_pairs).

        Arguments
        ---------
        XXs : torch.Tensor
            The covariance matrices of the input signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        eps : float
            A small value to avoid divisions by 0 with the phase transform. The
            default value is 1e-20.

        Returns
        -------
        xxs : torch.Tensor
        """
        # Frame length reconstructed from the number of one-sided bins
        n_samples = (XXs.shape[2] - 1) * 2

        # Process only distinct channel pairs; XXs_idx restores them later
        XXs_val, XXs_idx = torch.unique(XXs, return_inverse=True, dim=4)
        XXs_re = XXs_val[..., 0, :]
        XXs_im = XXs_val[..., 1, :]

        # Phase transform: keep only the phase of each cross-spectrum bin
        XXs_abs = torch.sqrt(XXs_re**2 + XXs_im**2) + eps
        XXs_phat = torch.complex(XXs_re / XXs_abs, XXs_im / XXs_abs)

        # (batch, time, pairs, freq) -> back to the temporal domain
        xxs = torch.fft.irfft(XXs_phat.transpose(2, 3), n=n_samples)
        xxs = xxs[..., XXs_idx, :]

        # (batch, time_steps, n_fft, n_mics + n_pairs)
        return xxs.transpose(2, 3)

    @staticmethod
    def _extract_delays(xxs, tdoa_max=None):
        """Extract the rounded delays from the cross-correlation for each
        timestamp. The result has the format (batch, time_steps, n_mics + n_pairs).

        Arguments
        ---------
        xxs : torch.Tensor
            The correlation signals obtained after a gcc-phat operation. The tensor
            must have the format (batch, time_steps, n_fft, n_mics + n_pairs).
        tdoa_max : int
            Restricts the search to delays between -tdoa_max and
            tdoa_max - 1 samples. None means the full frame
            (-n_fft/2 to n_fft/2 - 1); larger values are clamped to n_fft/2.

        Returns
        -------
        delays : torch.Tensor

        Raises
        ------
        ValueError
            If tdoa_max is smaller than 1 (a zero window would otherwise
            silently select the whole frame through ``-0:`` slicing).
        """
        n_fft = xxs.shape[2]
        half_frame = n_fft // 2

        if tdoa_max is None:
            tdoa_max = half_frame
        elif tdoa_max < 1:
            raise ValueError(
                f"tdoa_max must be a positive integer, got {tdoa_max}"
            )
        else:
            # Larger windows would overlap and yield out-of-range lags
            tdoa_max = min(tdoa_max, half_frame)

        # Non-negative lags come first in the circular correlation,
        # negative lags live at the end of the frame
        slice_1 = xxs[..., 0:tdoa_max, :]
        slice_2 = xxs[..., -tdoa_max:, :]
        xxs_sliced = torch.cat((slice_1, slice_2), 2)

        _, delays = torch.max(xxs_sliced, 2)

        # Indices that fell in slice_2 map to negative lags in [-tdoa_max, -1]
        wrapped = delays >= slice_1.shape[2]
        delays[wrapped] -= xxs_sliced.shape[2]

        return delays

    @staticmethod
    def _interpolate(xxs, delays):
        """Refine the rounded delays with quadratic (parabolic) interpolation
        around the correlation peak. The result has the format
        (batch, time_steps, n_mics + n_pairs).

        Arguments
        ---------
        xxs : torch.Tensor
            The correlation signals obtained after a gcc-phat operation. The tensor
            must have the format (batch, time_steps, n_fft, n_mics + n_pairs).
        delays : torch.Tensor
            The rounded tdoas obtained by selecting the sample with the highest
            amplitude. The tensor must have the format
            (batch, time_steps, n_mics + n_pairs).

        Returns
        -------
        delays_frac : torch.Tensor
            Where the three samples around the peak are collinear (zero
            curvature), no fractional correction is applied instead of
            producing NaN/inf.
        """
        n_fft = xxs.shape[2]

        # Peak sample and its two neighbours, with circular wrap-around
        tp = torch.fmod((delays - 1) + n_fft, n_fft).unsqueeze(2)
        y1 = torch.gather(xxs, 2, tp).squeeze(2)
        tp = torch.fmod(delays + n_fft, n_fft).unsqueeze(2)
        y2 = torch.gather(xxs, 2, tp).squeeze(2)
        tp = torch.fmod((delays + 1) + n_fft, n_fft).unsqueeze(2)
        y3 = torch.gather(xxs, 2, tp).squeeze(2)

        # Vertex of the parabola through (y1, y2, y3)
        num = y1 - y3
        den = 2 * y1 - 4 * y2 + 2 * y3
        flat = den == 0
        # Substitute a safe denominator so neither the value nor its gradient
        # becomes NaN in the masked-out entries
        safe_den = torch.where(flat, torch.ones_like(den), den)
        offset = torch.where(flat, torch.zeros_like(num), num / safe_den)

        return delays + offset
class SrpPhat(torch.nn.Module):
"""Steered-Response Power with Phase Transform Localization.
Arguments
---------
mics : torch.Tensor
The cartesian coordinates (xyz) in meters of each microphone.
The tensor must have the following format (n_mics, 3).
space : string
If this parameter is set to 'sphere', the localization will
be done in 3D by searching in a sphere of possible doas. If
it set to 'circle', the search will be done in 2D by searching
in a circle. By default, this parameter is set to 'sphere'.
Note: The 'circle' option isn't implemented yet.
sample_rate : int
The sample rate in Hertz of the signals to perform SRP-PHAT on.
By default, this parameter is set to 16000 Hz.
speed_sound : float
The speed of sound in the medium. The speed is expressed in meters
per second and the default value of this parameter is 343 m/s.
eps : float
A small value to avoid errors like division by 0. The default value
of this parameter is 1e-20.
Example
-------
>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import SrpPhat
>>> xs_speech = read_audio(
... "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac"
... )
>>> xs_noise = read_audio("tests/samples/multi-mic/noise_diffuse.flac")
>>> fs = 16000
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = xs_noise.unsqueeze(0)
>>> ss1 = xs_speech
>>> ns1 = 0.05 * xs_noise
>>> xs1 = ss1 + ns1
>>> ss2 = xs_speech
>>> ns2 = 0.20 * xs_noise
>>> xs2 = ss2 + ns2
>>> ss = torch.cat((ss1, ss2), dim=0)
>>> ns = torch.cat((ns1, ns2), dim=0)
>>> xs = torch.cat((xs1, xs2), dim=0)
>>> mics = torch.zeros((4, 3), dtype=torch.float)
>>> mics[0, :] = torch.FloatTensor([-0.05, -0.05, +0.00])
>>> mics[1, :] = torch.FloatTensor([-0.05, +0.05, +0.00])
>>> mics[2, :] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> mics[3, :] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> srpphat = SrpPhat(mics=mics)
>>> Xs = stft(xs)
>>> XXs = cov(Xs)
>>> doas = srpphat(XXs)
"""
def __init__(
    self,
    mics,
    space="sphere",
    sample_rate=16000,
    speed_sound=343.0,
    eps=1e-20,
):
    """Precompute the candidate DOAs and their associated delays (taus).

    Raises
    ------
    NotImplementedError
        If space is 'circle' (not implemented yet).
    ValueError
        If space is neither 'sphere' nor 'circle'.
    """
    super().__init__()

    # Generate the candidate doas. Previously 'circle' (and any unknown
    # value) fell through with self.doas unset, failing later with an
    # unrelated AttributeError.
    if space == "sphere":
        self.doas = sphere()
    elif space == "circle":
        raise NotImplementedError(
            "The 'circle' search space is not implemented yet."
        )
    else:
        raise ValueError(
            f"Unknown search space '{space}'; expected 'sphere' or 'circle'."
        )

    # Generate associated taus with the doas
    self.taus = doas2taus(self.doas, mics=mics, fs=sample_rate, c=speed_sound)

    # Save epsilon
    self.eps = eps
def forward(self, XXs):
    """Localize the sound source with SRP-PHAT.

    Steering vectors are built from the precomputed delays (moved to the
    device of XXs), then the static helper _srp_phat selects the
    best-matching direction. The output holds, for every time step, the
    xyz coordinates (in meters) pointing towards the sound source, with
    format (batch, time_steps, 3).

    This localization method uses Global Coherence Field (GCF):
    https://www.researchgate.net/publication/221491705_Speaker_localization_based_on_oriented_global_coherence_field

    Arguments
    ---------
    XXs : torch.Tensor
        The covariance matrices of the input signal. The tensor must
        have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).

    Returns
    -------
    doas : torch.Tensor
    """
    candidate_taus = self.taus.to(XXs.device)
    n_bins = XXs.shape[2]
    steering_vectors = steering(candidate_taus, n_bins)
    return SrpPhat._srp_phat(
        XXs=XXs, As=steering_vectors, doas=self.doas, eps=self.eps
    )
@staticmethod
def _srp_phat(XXs, As, doas, eps=1e-20):
"""Perform srp-phat to find the direction of arrival
of the sound source. The result is a tensor containing the directions
of arrival (xyz coordinates (in meters) in the direction of the sound source).
The output tensor has the format: (batch, time_steps, 3).
Arguments
---------
XXs : torch.Tensor
The covariance matrices of the input signal. The tensor must
have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
As : torch.Tensor
The steering vector that cover the all the potential directions
of arrival. The tensor must have the format