Audio Classification with PyTorch¶
Introduction¶
This project is a PyTorch-based sound classification system that recognizes environmental sounds, animal calls, and spoken languages. It provides several classification models, such as EcapaTdnn, PANNS, ResNetSE, CAMPPlus, and ERes2Net, to cover different application scenarios. The project also includes test reports for the widely used UrbanSound8K dataset and examples for downloading and using some dialect datasets, so users can choose the model and dataset that best fit their needs. Typical applications include outdoor environmental monitoring, wildlife protection, and speech recognition, and users are encouraged to explore further use cases to advance sound classification technology.
Source Code: AudioClassification-Pytorch
Prerequisites¶
- Anaconda 3
- Python 3.11
- PyTorch 2.0.1
- Windows 11 or Ubuntu 22.04
Project Features¶
- Supported Models: EcapaTdnn, PANNS, TDNN, Res2Net, ResNetSE, CAMPPlus, ERes2Net
- Supported Pooling Layers: AttentiveStatsPool(ASP), SelfAttentivePooling(SAP), TemporalStatisticsPooling(TSP), TemporalAveragePooling(TAP)
- Supported Preprocessing Methods: MelSpectrogram, Spectrogram, MFCC, Fbank
Model Papers:
- EcapaTdnn: ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification
- PANNS: PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
- TDNN: Prediction of speech intelligibility with DNN-based performance measures
- Res2Net: Res2Net: A New Multi-scale Backbone Architecture
- ResNetSE: Squeeze-and-Excitation Networks
- CAMPPlus: CAM++: A Fast and Efficient Network for Speaker Verification Using Context-Aware Masking
- ERes2Net: An Enhanced Res2Net with Local and Global Feature Fusion for Speaker Verification
Model Test Results¶
| Model | Params(M) | Preprocessing | Dataset | Number of Classes | Accuracy |
|---|---|---|---|---|---|
| ResNetSE | 7.8 | Fbank | UrbanSound8K | 10 | 0.98863 |
| CAMPPlus | 7.1 | Fbank | UrbanSound8K | 10 | 0.97727 |
| ERes2Net | 6.6 | Fbank | UrbanSound8K | 10 | 0.96590 |
| PANNS (CNN10) | 5.2 | Fbank | UrbanSound8K | 10 | 0.96590 |
| Res2Net | 5.0 | Fbank | UrbanSound8K | 10 | 0.94318 |
| TDNN | 2.6 | Fbank | UrbanSound8K | 10 | 0.92045 |
| EcapaTdnn | 6.1 | Fbank | UrbanSound8K | 10 | 0.91876 |
Installation¶
Install PyTorch (GPU version)¶
If you haven’t installed PyTorch yet:
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
Install macls Library¶
Install using pip:
python -m pip install macls -U -i https://pypi.tuna.tsinghua.edu.cn/simple
Recommended Source Installation (for latest code):
git clone https://github.com/yeyupiaoling/AudioClassification-Pytorch.git
cd AudioClassification-Pytorch/
python setup.py install
Data Preparation¶
Generate data lists for the training and evaluation scripts to read; audio_path is the path to an audio file. Place the audio dataset in the dataset/audio directory, with each subfolder holding the audio for one category (e.g., dataset/audio/birdcalls/...), and make sure each audio file is at least 3 seconds long. The generated lists are stored in the dataset directory, with each line formatted as audio_path\tlabel.
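For a custom dataset, such a list can be produced with a short script. The sketch below only illustrates the audio_path\tlabel format; the directory layout, output file names, and the simple hold-out split are assumptions, not the exact logic of the project's create_data.py.
import os

# Illustrative stand-in for create_data.py (assumed layout: dataset/audio/<class>/<file>.wav).
def build_lists(audio_root='dataset/audio', out_dir='dataset'):
    classes = sorted(d for d in os.listdir(audio_root)
                     if os.path.isdir(os.path.join(audio_root, d)))
    train_lines, test_lines = [], []
    for label, name in enumerate(classes):
        files = sorted(os.listdir(os.path.join(audio_root, name)))
        for i, file in enumerate(files):
            line = f"{os.path.join(audio_root, name, file)}\t{label}\n"
            # Hold out every fifth file for testing (assumed split, adjust as needed).
            (test_lines if i % 5 == 0 else train_lines).append(line)
    with open(os.path.join(out_dir, 'train_list.txt'), 'w', encoding='utf-8') as f:
        f.writelines(train_lines)
    with open(os.path.join(out_dir, 'test_list.txt'), 'w', encoding='utf-8') as f:
        f.writelines(test_lines)

build_lists()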
For the UrbanSound8K dataset (widely used for urban sound classification):
1. Download the dataset from UrbanSound8K.tar.gz
2. Extract it to the dataset directory
3. Run create_data.py to generate the data list (modifying it as needed for your setup):
python create_data.py
The generated data list will look like this:
dataset/UrbanSound8K/audio/fold2/104817-4-0-2.wav 4
dataset/UrbanSound8K/audio/fold9/105029-7-2-5.wav 7
dataset/UrbanSound8K/audio/fold3/107228-5-0-0.wav 5
dataset/UrbanSound8K/audio/fold4/109711-3-2-4.wav 3
Modify Preprocessing (Optional)¶
The default preprocessing method is Fbank. To change it, modify the configuration file:
# Data Preprocessing Parameters
preprocess_conf:
  use_hf_model: False       # Whether to use a HuggingFace model for feature extraction
  feature_method: 'Fbank'   # When use_hf_model=False, options: MelSpectrogram, Spectrogram, MFCC, Fbank
  method_args:
    sample_frequency: 16000
    num_mel_bins: 80
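To get a feel for what these parameters produce, the standalone sketch below computes Kaldi-style Fbank features with torchaudio using the same sample_frequency and num_mel_bins. It only illustrates the configured preprocessing, not the project's internal pipeline, and the audio path is a placeholder.
import torchaudio
from torchaudio.compliance import kaldi

# Load a waveform and resample it to the configured 16 kHz if needed.
waveform, sr = torchaudio.load('dataset/audio/birdcalls/example.wav')  # hypothetical file
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)

# 80-dimensional Fbank features, matching sample_frequency / num_mel_bins above.
fbank = kaldi.fbank(waveform, sample_frequency=16000, num_mel_bins=80)
print(fbank.shape)  # (num_frames, 80)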
Feature Extraction (Optional)¶
Feature extraction can be time-consuming. To speed up training, pre-extract features:
- Run extract_features.py:
python extract_features.py --configs=configs/cam++.yml --save_dir=dataset/features
- Modify the configuration file to use the pre-extracted feature lists (an illustrative sketch of pre-computing features follows these steps):
dataset_conf:
  train_list: dataset/train_list_features.txt
  test_list: dataset/test_list_features.txt
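Conceptually, pre-extraction just saves each file's features to disk and writes a parallel list pointing at the saved arrays. The sketch below is a generic illustration of that idea using .npy files; the actual on-disk format and list names produced by extract_features.py may differ, so treat the paths here as assumptions.
import os
import numpy as np
import torchaudio
from torchaudio.compliance import kaldi

# Generic illustration: read "audio_path\tlabel" lines, save Fbank features as .npy,
# and write a "feature_path\tlabel" list (format assumed, not extract_features.py's own).
def precompute(list_path='dataset/train_list.txt',
               out_list='dataset/train_list_features.txt',
               save_dir='dataset/features'):
    os.makedirs(save_dir, exist_ok=True)
    with open(list_path, encoding='utf-8') as f_in, open(out_list, 'w', encoding='utf-8') as f_out:
        for i, line in enumerate(f_in):
            audio_path, label = line.strip().split('\t')
            waveform, sr = torchaudio.load(audio_path)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, sr, 16000)
            feats = kaldi.fbank(waveform, sample_frequency=16000, num_mel_bins=80).numpy()
            feat_path = os.path.join(save_dir, f'{i}.npy')
            np.save(feat_path, feats)
            f_out.write(f'{feat_path}\t{label}\n')

precompute()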
Training¶
# Single GPU training
CUDA_VISIBLE_DEVICES=0 python train.py
# Multi-GPU training
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes=1 --nproc_per_node=2 train.py
Training Log Example:
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:14 - ----------- Additional Configuration -----------
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - configs: configs/ecapa_tdnn.yml
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - local_rank: 0
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - pretrained_model: None
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - resume_model: None
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - save_model_path: models/
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - use_gpu: True
...
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:22 - train_conf:
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:29 - log_interval: 10
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:29 - max_epoch: 30
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:31 - use_model: EcapaTdnn
Training Visualization:
visualdl --logdir=log --host=0.0.0.0
Access http://localhost:8040/ in your browser to view training metrics.
Evaluation¶
python eval.py --configs=configs/cam++.yml
Evaluation Output:
[2024-02-03 15:13:25.469242 INFO ] trainer:evaluate:461 - Model successfully loaded: models/CAMPPlus_Fbank/best_model/model.pth
100%|██████████████████████████████| 150/150 [00:00<00:00, 1281.96it/s]
Evaluation Time: 1s, Loss: 0.61840, Accuracy: 0.87333
Prediction¶
python infer.py --audio_path=dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav
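If you prefer to run prediction from Python instead of the CLI, something along the lines of the sketch below should work, assuming the pip-installed macls package exposes a MAClsPredictor class; the class name, argument names, and paths here are assumptions, so consult infer.py for the exact interface.
from macls.predict import MAClsPredictor  # assumed entry point of the macls package

# Paths below are assumptions; point them at your own config and trained model.
predictor = MAClsPredictor(configs='configs/cam++.yml',
                           model_path='models/CAMPPlus_Fbank/best_model/',
                           use_gpu=True)
label, score = predictor.predict(audio_data='dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav')
print(f'Predicted class: {label}, score: {score}')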
Additional Features¶
- Record Audio: Use record_audio.py to record audio at a 16000 Hz sampling rate, mono, 16-bit (a minimal recording sketch follows this list):
python record_audio.py
- Real-time Recognition: Use infer_record.py for continuous audio recording and recognition:
python infer_record.py --record_seconds=3
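For reference, recording 16 kHz mono 16-bit audio from the default microphone can be done with the sounddevice and soundfile packages, as in the minimal sketch below. This only illustrates the recording parameters the project expects; record_audio.py itself may use a different audio backend.
import sounddevice as sd
import soundfile as sf

SAMPLE_RATE = 16000   # matches the project's 16000 Hz requirement
DURATION = 3          # seconds to record (assumed, mirrors --record_seconds=3)

# Record mono 16-bit PCM from the default input device and save it as WAV.
audio = sd.rec(int(DURATION * SAMPLE_RATE), samplerate=SAMPLE_RATE,
               channels=1, dtype='int16')
sd.wait()  # block until the recording is finished
sf.write('recorded.wav', audio, SAMPLE_RATE, subtype='PCM_16')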
References¶
- https://github.com/PaddlePaddle/PaddleSpeech
- https://github.com/yeyupiaoling/PaddlePaddle-MobileFaceNets
- https://github.com/yeyupiaoling/PPASR
- https://github.com/alibaba-damo-academy/3D-Speaker