Introduction

This chapter shows how to implement a simple voiceprint recognition model with TensorFlow. It assumes some familiarity with audio classification; if you need a refresher, see the article Implementing Sound Classification with TensorFlow. Building on that, we will train a voiceprint recognition model that can identify who is speaking, which is useful in projects that require voice-based verification.

Environment Preparation

This section mainly introduces the installation of librosa, PyAudio, and pydub. Other dependencies can be installed as needed.
- Python 3.7
- TensorFlow 2.0

Install librosa

The easiest way is to use the pip command:

pip install pytest-runner
pip install librosa

If the pip installation fails, install from source code. Download the source code from: https://github.com/librosa/librosa/releases/. For Windows, download the zip package for easy extraction:

pip install pytest-runner
tar xzf librosa-<version>.tar.gz  # or unzip librosa-<version>.tar.gz
cd librosa-<version>/
python setup.py install

If you encounter an error mentioning libsndfile64bit.dll (error 0x7e), install version 0.6.3 instead:

pip install librosa==0.6.3

Install PyAudio

Use pip to install:

pip install pyaudio

On Windows, if the installation fails because it tries to compile against C++ libraries, download a prebuilt wheel from https://github.com/intxcc/pyaudio_portaudio/releases and install it with pip.

Install pydub

Use pip to install:

pip install pydub
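
To confirm the environment is ready before moving on, a quick import check is enough (a minimal sketch; it only verifies that the packages above can be loaded):

import tensorflow as tf
import librosa
import pyaudio
from pydub import AudioSegment  # confirms pydub installed correctly

print("TensorFlow:", tf.__version__)
print("librosa:", librosa.__version__)
print("PyAudio:", pyaudio.__version__)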

Data Preparation

This tutorial uses the Free ST Chinese Mandarin Corpus, which contains 102,600 audio files from 855 speakers. You can also mix in other datasets if you have them.

As explained in the sound classification article, when a dataset consists of a large number of small audio files, packing them into TFRecord files is the most efficient way to read them. Create create_data.py to generate the TFRecord files.

Step 1: Create Data List

Generate training and test lists of audio files and labels, one entry per line in the format <audio_path>\t<label>. Such a list keeps data loading simple and makes it easy to combine several datasets.

import os

def get_data_list(audio_path, list_path):
    # Sort so files from the same speaker are grouped together,
    # which keeps the incremental labels below consistent
    files = sorted(os.listdir(audio_path))

    f_train = open(os.path.join(list_path, 'train_list.txt'), 'w')
    f_test = open(os.path.join(list_path, 'test_list.txt'), 'w')

    sound_sum = 0
    s = set()
    for file in files:
        if '.wav' not in file:
            continue
        s.add(file[:15])  # Use first 15 characters as unique speaker ID
        sound_path = os.path.join(audio_path, file)
        if sound_sum % 100 == 0:
            f_test.write('%s\t%d\n' % (sound_path.replace('\\', '/'), len(s) - 1))
        else:
            f_train.write('%s\t%d\n' % (sound_path.replace('\\', '/'), len(s) - 1))
        sound_sum += 1

    f_test.close()
    f_train.close()

if __name__ == '__main__':
    get_data_list('dataset/ST-CMDS-20170001_1-OS', 'dataset')
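
To sanity-check the result, you can print the first few entries of the generated list (paths depend on where you extracted the corpus):

with open('dataset/train_list.txt', 'r') as f:
    for line in f.readlines()[:3]:
        path, label = line.rstrip('\n').split('\t')
        print(path, '-> speaker label', label)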

Step 2: Convert to TFRecord

Remove silent segments with librosa.effects.split, then convert the remaining audio to Mel spectrograms with librosa.feature.melspectrogram.
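
Before walking through the full script, here is a minimal sketch of just these two librosa calls with the same parameters used below (sr=16000, top_db=20, hop_length=256); the file path is one of the corpus files and only serves as an example:

import librosa
import numpy as np

# Load at 16 kHz and keep only the non-silent intervals
wav, sr = librosa.load('dataset/ST-CMDS-20170001_1-OS/20170001P00011A0001.wav', sr=16000)
intervals = librosa.effects.split(wav, top_db=20)        # array of (start, end) sample indices
voiced = np.concatenate([wav[s:e] for s, e in intervals])

# Mel spectrogram: 128 mel bands by default, one frame every 256 samples
ps = librosa.feature.melspectrogram(y=voiced, sr=sr, hop_length=256)
print(ps.shape)  # (128, n_frames); a 2.04 s crop of 32640 samples gives exactly 128 frames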

import tensorflow as tf
import librosa
import numpy as np
import os
from tqdm import tqdm
import random

def _float_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def data_example(data, label):
    feature = {
        'data': _float_feature(data),
        'label': _int64_feature(label),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

def create_data_tfrecord(data_list_path, save_path):
    with open(data_list_path, 'r') as f:
        data = f.readlines()
    with tf.io.TFRecordWriter(save_path) as writer:
        for d in tqdm(data):
            try:
                path, label = d.replace('\n', '').split('\t')
                wav, sr = librosa.load(path, sr=16000)
                intervals = librosa.effects.split(wav, top_db=20)
                wav_output = []
                for sliced in intervals:
                    wav_output.extend(wav[sliced[0]:sliced[1]])

                # Audio length: 16000 samples/sec * 2.04 sec = 32640 samples
                wav_len = int(16000 * 2.04)
                for i in range(20):  # Data augmentation: up to 20 random crops per file
                    if len(wav_output) > wav_len:
                        l = len(wav_output) - wav_len
                        r = random.randint(0, l)
                        crop = wav_output[r:r + wav_len]
                    else:
                        # Pad short audio with zeros instead of cropping
                        crop = wav_output + list(np.zeros(wav_len - len(wav_output), dtype=np.float32))
                    crop = np.array(crop, dtype=np.float32)
                    # Convert to Mel Spectrogram (128x128 for a 2.04 s crop)
                    ps = librosa.feature.melspectrogram(
                        y=crop, sr=sr, hop_length=256
                    ).reshape(-1).tolist()
                    if len(ps) != 128 * 128:  # Check shape
                        continue
                    tf_example = data_example(ps, int(label))
                    writer.write(tf_example.SerializeToString())
                    if len(wav_output) <= wav_len:
                        break  # Short audio only yields one (padded) example
            except Exception as e:
                print(e)

if __name__ == '__main__':
    create_data_tfrecord('dataset/train_list.txt', 'dataset/train.tfrecord')
    create_data_tfrecord('dataset/test_list.txt', 'dataset/test.tfrecord')

Step 3: Create TFRecord Reader

Create reader.py to read the TFRecord data. The data feature is a fixed-length vector of 16384 floats, i.e. the flattened 128x128 Mel spectrogram (2.04 s at 16 kHz with hop_length=256 gives 1 + 32640 // 256 = 128 frames), so adjust the tf.io.FixedLenFeature size if you change the audio length.

import tensorflow as tf

def _parse_data_function(example):
    data_feature_description = {
        'data': tf.io.FixedLenFeature([16384], tf.float32),  # 128*128
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    return tf.io.parse_single_example(example, data_feature_description)

def train_reader_tfrecord(data_path, num_epochs, batch_size):
    raw_dataset = tf.data.TFRecordDataset(data_path)
    # tf.data.experimental.AUTOTUNE keeps this compatible with TensorFlow 2.0
    return raw_dataset.map(_parse_data_function).shuffle(1000).repeat(num_epochs).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

def test_reader_tfrecord(data_path, batch_size):
    raw_dataset = tf.data.TFRecordDataset(data_path)
    return raw_dataset.map(_parse_data_function).batch(batch_size)
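
A quick way to verify the pipeline end to end is to read one batch back and check its shape; this sketch assumes train.tfrecord has already been generated by create_data.py:

import reader

# Read a single batch and confirm the expected (batch, 128, 128, 1) shape
dataset = reader.train_reader_tfrecord('dataset/train.tfrecord', num_epochs=1, batch_size=8)
for batch in dataset.take(1):
    sounds = batch['data'].numpy().reshape(-1, 128, 128, 1)
    print(sounds.shape, batch['label'].numpy())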

Model Training

Create train.py to train a ResNet50V2 model. The input shape (128, None, 1) accepts Mel spectrograms with any number of frames, so the same network can later process audio of arbitrary length during inference.

import tensorflow as tf
import reader
import librosa
import numpy as np
import os

class_dim = 855  # Number of speakers
EPOCHS = 500
BATCH_SIZE = 32
init_model = "models/model_weights.h5"

# Build model
model = tf.keras.Sequential([
    tf.keras.applications.ResNet50V2(
        include_top=False, weights=None, input_shape=(128, None, 1)
    ),
    tf.keras.layers.ActivityRegularization(l2=0.5),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.GlobalMaxPooling2D(),
    tf.keras.layers.Dense(class_dim, activation='softmax')
])

model.summary()

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
train_dataset = reader.train_reader_tfrecord('dataset/train.tfrecord', EPOCHS, BATCH_SIZE)
test_dataset = reader.test_reader_tfrecord('dataset/test.tfrecord', BATCH_SIZE)

os.makedirs('models', exist_ok=True)

# Resume from previously saved weights if they exist
if init_model and os.path.exists(init_model):
    model.load_weights(init_model)
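
Because the width dimension is None, the same network accepts Mel spectrograms with any number of frames, which the inference scripts below rely on. A quick optional check (random arrays stand in for real spectrograms):

dummy_short = np.random.rand(1, 128, 128, 1).astype(np.float32)
dummy_long = np.random.rand(1, 128, 300, 1).astype(np.float32)
print(model(dummy_short).shape, model(dummy_long).shape)  # both (1, 855)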

Train and validate every 200 batches:

for batch_id, data in enumerate(train_dataset):
    # Reshape Mel Spectrogram to (batch_size, 128, 128, 1)
    sounds = data['data'].numpy().reshape(-1, 128, 128, 1)
    labels = data['label']

    with tf.GradientTape() as tape:
        predictions = model(sounds)
        loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(labels, predictions))
        acc = tf.reduce_mean(tf.keras.metrics.sparse_categorical_accuracy(labels, predictions))

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    if batch_id % 20 == 0:
        print(f"Batch {batch_id}, Loss: {loss.numpy():.4f}, Acc: {acc.numpy():.4f}")

    if batch_id % 200 == 0 and batch_id != 0:
        # Validate
        test_losses, test_accs = [], []
        for d in test_dataset:
            test_sounds = d['data'].numpy().reshape(-1, 128, 128, 1)
            test_labels = d['label']
            test_preds = model(test_sounds)
            test_loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(test_labels, test_preds))
            test_acc = tf.reduce_mean(tf.keras.metrics.sparse_categorical_accuracy(test_labels, test_preds))
            test_losses.append(test_loss.numpy())
            test_accs.append(test_acc.numpy())
        print(f"Test Loss: {np.mean(test_losses):.4f}, Acc: {np.mean(test_accs):.4f}")
        # Save weights for resuming training and the full model for the inference scripts below
        model.save_weights('models/model_weights.h5')
        model.save('models/resnet.h5')

Voiceprint Comparison

Create infer_contrast.py to extract voiceprint features with the trained model, using the output of the global max-pooling layer just before the classification layer. The cosine similarity between two feature vectors then indicates how likely the two recordings are from the same speaker.

import tensorflow as tf
from tensorflow.keras.models import Model
import librosa
import numpy as np
import os

# Load intermediate layer (before classification)
layer_name = 'global_max_pooling2d'
model = tf.keras.models.load_model('models/resnet.h5')
intermediate_model = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)

def load_data(data_path):
    wav, sr = librosa.load(data_path, sr=16000)
    intervals = librosa.effects.split(wav, top_db=20)
    wav_output = []
    for sliced in intervals:
        wav_output.extend(wav[sliced[0]:sliced[1]])
    if len(wav_output) < 8000:  # Minimum 0.5s
        raise Exception("Audio too short")
    wav_output = np.array(wav_output)
    ps = librosa.feature.melspectrogram(y=wav_output, sr=sr, hop_length=256).astype(np.float32)
    return ps[np.newaxis, ..., np.newaxis]  # Add batch and channel dims

def infer(audio_path):
    data = load_data(audio_path)
    return intermediate_model.predict(data)[0]

# Compare two voiceprints
if __name__ == '__main__':
    person1 = 'dataset/ST-CMDS-20170001_1-OS/20170001P00011A0001.wav'
    person2 = 'dataset/ST-CMDS-20170001_1-OS/20170001P00011I0081.wav'
    feat1 = infer(person1)
    feat2 = infer(person2)
    # Cosine similarity
    dist = np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2))
    print(f"Similarity: {dist:.4f}")

Voiceprint Recognition

Create infer_recognition.py to recognize speakers by comparing against a registered voice database.

import tensorflow as tf
from tensorflow.keras.models import Model
import librosa
import numpy as np
import os

layer_name = 'global_max_pooling2d'
model = tf.keras.models.load_model('models/resnet.h5')
intermediate_model = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)

person_features = []
person_names = []

def load_data(data_path):
    wav, sr = librosa.load(data_path, sr=16000)
    intervals = librosa.effects.split(wav, top_db=20)
    wav_output = []
    for sliced in intervals:
        wav_output.extend(wav[sliced[0]:sliced[1]])
    if len(wav_output) < 8000:
        raise Exception("Audio too short")
    wav_output = np.array(wav_output)
    ps = librosa.feature.melspectrogram(y=wav_output, sr=sr, hop_length=256).astype(np.float32)
    return ps[np.newaxis, ..., np.newaxis]

def infer(audio_path):
    data = load_data(audio_path)
    return intermediate_model.predict(data)[0]

def load_audio_db(db_path):
    for file in os.listdir(db_path):
        path = os.path.join(db_path, file)
        name = file[:-4]
        person_features.append(infer(path))
        person_names.append(name)
        print(f"Loaded {name}")

def recognition(audio_path):
    max_sim = 0
    best_name = ""
    query_feat = infer(audio_path)
    for i, feat in enumerate(person_features):
        sim = np.dot(query_feat, feat) / (np.linalg.norm(query_feat) * np.linalg.norm(feat))
        if sim > max_sim:
            max_sim = sim
            best_name = person_names[i]
    return best_name, max_sim

# Demo: Record 3 seconds from the microphone and recognize the speaker
if __name__ == '__main__':
    load_audio_db('audio_db')  # Pre-registered speakers
    import pyaudio, wave
    p = pyaudio.PyAudio()
    sample_width = p.get_sample_size(pyaudio.paInt16)  # 2 bytes per 16-bit sample
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=1024)
    input("Press Enter to record...")
    frames = [stream.read(1024) for _ in range(16000 * 3 // 1024)]  # ~3s
    stream.stop_stream()
    stream.close()
    p.terminate()
    # Write a valid WAV file: channels, sample width and rate must be set before the frames
    wf = wave.open("temp.wav", 'wb')
    wf.setnchannels(1)
    wf.setsampwidth(sample_width)
    wf.setframerate(16000)
    wf.writeframes(b''.join(frames))
    wf.close()
    name, sim = recognition("temp.wav")
    print(f"Recognized: {name} (Similarity: {sim:.4f})")

GitHub Repository: https://github.com/yeyupiaoling/VoiceprintRecognition-Tensorflow/tree/master

Other Versions:
- PaddlePaddle: VoiceprintRecognition-PaddlePaddle
- PyTorch: VoiceprintRecognition-Pytorch
- Keras: VoiceprintRecognition-Keras

Xiaoye