Xingyu Bian committed
Commit 9437579
Parent(s): fa77754

added diarization plot

Files changed (3):
  1. .gitignore +3 -0
  2. app.py +61 -7
  3. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+.env
+__pycache__/
+flagged/
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
 from pyannote.audio import Pipeline
 import os
 from dotenv import load_dotenv
+import plotly.graph_objects as go
 
 load_dotenv()
 
@@ -38,7 +39,51 @@ diarization_pipeline = Pipeline.from_pretrained(
 )
 
 
-def transcribe(audio):
+def diarization_info(res):
+    starts = []
+    ends = []
+    speakers = []
+
+    for segment, track, _ in res.itertracks(yield_label=True):
+        starts.append(segment.start)
+        ends.append(segment.end)
+        speakers.append(track)
+
+    return starts, ends, speakers
+
+
+def plot_diarization(starts, ends, speakers):
+    fig = go.Figure()
+
+    # Define a color map for different speakers
+    num_speakers = len(set(speakers))
+    colors = [f"hsl({h},80%,60%)" for h in np.linspace(0, 360, num_speakers)]
+
+    # Plot each segment with its speaker's color
+    for start, end, speaker in zip(starts, ends, speakers):
+        speaker_id = list(set(speakers)).index(speaker)
+        fig.add_trace(
+            go.Scatter(
+                x=[start, end],
+                y=[speaker_id, speaker_id],
+                mode="lines",
+                line=dict(color=colors[speaker_id], width=15),
+                showlegend=False,
+            )
+        )
+
+    fig.update_layout(
+        title="Speaker Diarization",
+        xaxis=dict(title="Time"),
+        yaxis=dict(title="Speaker"),
+        height=600,
+        width=800,
+    )
+
+    return fig
+
+
+def transcribe_diarize(audio):
     sr, data = audio
     processed_data = np.array(data).astype(np.float32) / 32767.0
     waveform_tensor = torch.tensor(processed_data[np.newaxis, :])
@@ -49,19 +94,28 @@ def transcribe(audio):
         {"waveform": waveform_tensor, "sample_rate": sr}
     )
 
-    return transcription_res, diarization_res
+    # Get diarization information
+    starts, ends, speakers = diarization_info(diarization_res)
 
+    # Plot diarization
+    diarization_plot = plot_diarization(starts, ends, speakers)
 
+    return transcription_res, diarization_res, diarization_plot
+
+
+# creating the gradio interface
 demo = gr.Interface(
-    fn=transcribe,
+    fn=transcribe_diarize,
     inputs=gr.Audio(sources=["upload", "microphone"]),
     outputs=[
-        gr.Textbox(lines=3, info="audio transcription"),
-        gr.Textbox(info="speaker diarization"),
+        gr.Textbox(lines=3, label="Text Transcription"),
+        gr.Textbox(label="Speaker Diarization"),
+        gr.Plot(),
     ],
-    title="Automatic Speech Recognition 🗣️",
-    description="Transcribe your speech to text with distilled whisper",
+    title="Automatic Speech Recognition with Diarization 🗣️",
+    description="Transcribe your speech to text with distilled whisper and diarization with pyannote. Get started by recording from your mic or uploading an audio file 🎙️",
 )
 
+
 if __name__ == "__main__":
     demo.launch()
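
Note: the two new helpers can be exercised on their own, without loading the models. The sketch below is illustration only, not part of the commit; it hand-builds a pyannote Annotation and assumes diarization_info and plot_diarization from app.py are in scope. Since itertracks(yield_label=True) yields (segment, track, label) tuples, the helper collects track names as the speaker values.

# Standalone sketch (illustration only, not part of the commit).
# Assumes diarization_info and plot_diarization from app.py are in scope.
from pyannote.core import Annotation, Segment

ann = Annotation()
ann[Segment(0.0, 2.5), "A"] = "SPEAKER_00"  # (segment, track) -> label
ann[Segment(2.5, 5.0), "B"] = "SPEAKER_01"

starts, ends, speakers = diarization_info(ann)
print(speakers)  # track names, e.g. ['A', 'B']

fig = plot_diarization(starts, ends, speakers)
fig.show()  # the same figure transcribe_diarize hands to gr.Plot()
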
requirements.txt CHANGED
@@ -7,3 +7,4 @@ pyannote.database==5.0.1
 pyannote.metrics==3.2.1
 pyannote.pipeline==3.0.1
 python-dotenv==1.0.0
+plotly==5.18.0