Smooth InfoMax

This repository contains the code for the paper:

Fabian Denoodt, Bart de Boer, and José Oramas - Smooth InfoMax - Towards Easier Post-Hoc Interpretability


Abstract

We introduce Smooth InfoMax (SIM), a self-supervised representation learning method that incorporates interpretability constraints into the latent representations at different depths of the network. Based on $\beta$-VAEs, SIM's architecture consists of probabilistic modules optimized locally with the InfoNCE loss to produce Gaussian-distributed representations regularized toward the standard normal distribution. This creates smooth, well-defined, and better-disentangled latent spaces, enabling easier post-hoc analysis. Evaluated on speech data, SIM preserves the large-scale training benefits of Greedy InfoMax while improving the effectiveness of post-hoc interpretability methods across layers.
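Each probabilistic module is trained locally; based on the abstract, module $m$'s objective can be sketched as an InfoNCE term plus a $\beta$-weighted KL term pulling the module's Gaussian posterior toward the standard normal. The notation below is ours rather than the paper's exact formulation, and the `encoder_config.kld_weight` override used in the commands further down plausibly plays the role of $\beta$:

```latex
\mathcal{L}_m \;=\; \mathcal{L}^{(m)}_{\text{InfoNCE}}
  \;+\; \beta \, D_{\mathrm{KL}}\!\left( q_m(z_m \mid z_{m-1}) \,\Vert\, \mathcal{N}(0, I) \right)
```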


Running the code and reproducing the experiments

LibriSpeech dataset

  • Download LibriSpeech
        mkdir datasets
        cd datasets || exit
        wget http://www.openslr.org/resources/12/train-clean-100.tar.gz
        tar -xzf train-clean-100.tar.gz || exit
        mkdir LibriSpeech100_labels_split
        cd LibriSpeech100_labels_split || exit
        gdown https://drive.google.com/uc?id=1vSHmncPsRY7VWWAd_BtoWs9-fQ5cBrEB # test split
        gdown https://drive.google.com/uc?id=1ubREoLQu47_ZDn39YWv1wvVPe2ZlIZAb # train split
        gdown https://drive.google.com/uc?id=1bLuDkapGBERG_VYPS7fNZl5GXsQ9z3p2 # converted_aligned_phones.zip
        unzip converted_aligned_phones.zip
        cd ../..
  • Set required variables
    # Required setup (overwrites hyperparameters and sets up wandb for logging)
    
    # By default these args will run SIM. If you want to run GIM instead, change `sim_audio_de_boer_distr_true` to `sim_audio_de_boer_distr_false`
    override='./logs sim_audio_de_boer_distr_true \
    --overrides encoder_config.dataset.num_workers=4 speakers_classifier_config.dataset.num_workers=4 phones_classifier_config.dataset.num_workers=4 decoder_config.dataset.num_workers=4 \
    encoder_config.dataset.dataset=1 phones_classifier_config.dataset.dataset=1 speakers_classifier_config.dataset.dataset=1 decoder_config.dataset.dataset=1 \
    encoder_config.dataset.batch_size=8 decoder_config.dataset.batch_size=8 speakers_classifier_config.dataset.batch_size=8 phones_classifier_config.dataset.batch_size=8 \
    seed=4 encoder_config.num_epochs=100 speakers_classifier_config.encoder_num=99 phones_classifier_config.encoder_num=99 decoder_config.encoder_num=99 decoder_config.num_epochs=50 \
    phones_classifier_config.num_epochs=10 speakers_classifier_config.num_epochs=10 speakers_classifier_config.gradient_clipping=1.0 phones_classifier_config.gradient_clipping=1.0 \
    encoder_config.kld_weight=0.001 \
    wandb_project_name=TODO wandb_entity=TODO '; # Please update this line
    
    # Log into WandB
    wandb login XXXXX-WANDB-KEY-PLEASE-USE-YOUR-OWN-XXXX;
  • Run SIM or GIM
    python -m main $override;
  • Run the classifiers
    echo 'Training classifier - speakers'; 
    python -m linear_classifiers.logistic_regression_speaker $override \
        speakers_classifier_config.dataset.dataset=1 \
        speakers_classifier_config.bias=True \
        encoder_config.deterministic=True;
    
    echo 'Training classifier - phones'; 
    python -m linear_classifiers.logistic_regression_phones $override \
        phones_classifier_config.dataset.dataset=1 \
        phones_classifier_config.bias=True \
        encoder_config.deterministic=True;
  • Run the interpretability analysis
    # Perform speaker classification with different encoder settings
    python -m linear_classifiers.logistic_regression_speaker $override \
        encoder_config.deterministic=False \
        speakers_classifier_config.bias=False \
        speakers_classifier_config.encoder_module=0 \
        speakers_classifier_config.encoder_layer=-1;
    
    python -m post_hoc_analysis.interpretability.main_speakers_analysis $override \
        encoder_config.deterministic=False \
        speakers_classifier_config.bias=False \
        speakers_classifier_config.encoder_module=0 \
        speakers_classifier_config.encoder_layer=-1;
    
    python -m linear_classifiers.logistic_regression_speaker $override \
        encoder_config.deterministic=False \
        speakers_classifier_config.bias=False \
        speakers_classifier_config.encoder_module=1 \
        speakers_classifier_config.encoder_layer=-1;
    
    python -m post_hoc_analysis.interpretability.main_speakers_analysis $override \
        encoder_config.deterministic=False \
        speakers_classifier_config.bias=False \
        speakers_classifier_config.encoder_module=1 \
        speakers_classifier_config.encoder_layer=-1;
    
    python -m linear_classifiers.logistic_regression_speaker $override \
        encoder_config.deterministic=False \
        speakers_classifier_config.bias=False \
        speakers_classifier_config.encoder_module=2 \
        speakers_classifier_config.encoder_layer=-1;
    
    python -m post_hoc_analysis.interpretability.main_speakers_analysis $override \
        encoder_config.deterministic=False \
        speakers_classifier_config.bias=False \
        speakers_classifier_config.encoder_module=2 \
        speakers_classifier_config.encoder_layer=-1;
    
    # Train decoder with different encoder modules
    python -m decoder.train_decoder $override \
        decoder_config.decoder_loss=0 \
        decoder_config.dataset.dataset=1 \
        decoder_config.encoder_module=0 \
        decoder_config.encoder_layer=-1;
    
    python -m decoder.train_decoder $override \
        decoder_config.decoder_loss=0 \
        decoder_config.dataset.dataset=1 \
        decoder_config.encoder_module=1 \
        decoder_config.encoder_layer=-1;
    
    python -m decoder.train_decoder $override \
        decoder_config.decoder_loss=0 \
        decoder_config.dataset.dataset=1 \
        decoder_config.encoder_module=2 \
        decoder_config.encoder_layer=-1;
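
The six speaker-analysis runs above differ only in `speakers_classifier_config.encoder_module`, so they can be generated with a small loop. This is a sketch: it assumes the same `$override` string set in "Set required variables", and `speaker_sweep` is a hypothetical helper name, not part of the repository. The function prints the commands rather than running them, so you can inspect them first.

```shell
# Print the six speaker-analysis commands (encoder modules 0-2, classifier + analysis).
# `speaker_sweep` is a hypothetical helper, not part of the repository.
speaker_sweep() {
  for module in 0 1 2; do
    for prog in linear_classifiers.logistic_regression_speaker \
                post_hoc_analysis.interpretability.main_speakers_analysis; do
      echo "python -m $prog \$override" \
           "encoder_config.deterministic=False" \
           "speakers_classifier_config.bias=False" \
           "speakers_classifier_config.encoder_module=$module" \
           "speakers_classifier_config.encoder_layer=-1"
    done
  done
}
speaker_sweep              # inspect the generated commands
# eval "$(speaker_sweep)"  # then run them in the current shell
```

The decoder trainings at the end of this section follow the same pattern and could be folded into a similar loop.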

Artificial Speech dataset

  • Download the dataset
     git clone https://github.com/fdenoodt/Artificial-Speech-Dataset
     mkdir -p datasets/corpus
     cp -r Artificial-Speech-Dataset/* datasets/corpus/
  • Set required variables
    # Required setup (overwrites hyperparameters and sets up wandb for logging)
      
    # By default these args will run SIM. If you want to run GIM instead, change `sim_audio_de_boer_distr_true` to `sim_audio_de_boer_distr_false`
    override='./logs sim_audio_de_boer_distr_true \
    --overrides \
    encoder_config.dataset.num_workers=4 \
    syllables_classifier_config.dataset.num_workers=4 \
    decoder_config.dataset.num_workers=4 \
    encoder_config.use_batch_norm=False \
    use_wandb=True \
    wandb_project_name=TODO wandb_entity=TODO '; # Please update this line
      
    # Log into WandB
    wandb login XXXXX-WANDB-KEY-PLEASE-USE-YOUR-OWN-XXXX;
  • Run SIM or GIM
    python -m main $override;
  • Run the classifiers
    echo 'Training classifier - syllables'; 
    python -m linear_classifiers.logistic_regression $override \
        syllables_classifier_config.bias=True \
        syllables_classifier_config.dataset.labels=syllables \
        encoder_config.deterministic=True;
    
    echo 'Training classifier - vowels'; 
    python -m linear_classifiers.logistic_regression $override \
        syllables_classifier_config.bias=True \
        syllables_classifier_config.dataset.labels=vowels \
        encoder_config.deterministic=True;
  • Run the interpretability analysis
    # Perform vowel classification with different encoder settings
    python -m linear_classifiers.logistic_regression $override \
        encoder_config.deterministic=False \
        syllables_classifier_config.bias=False \
        syllables_classifier_config.dataset.labels=vowels \
        syllables_classifier_config.encoder_module=0 \
        syllables_classifier_config.encoder_layer=-1;
    
    echo 'vowel weight plots'; 
    python -m post_hoc_analysis.interpretability.main_vowel_classifier_analysis $override \
        encoder_config.deterministic=False \
        syllables_classifier_config.bias=False \
        syllables_classifier_config.dataset.labels=vowels \
        syllables_classifier_config.encoder_module=0 \
        syllables_classifier_config.encoder_layer=-1;
    
    python -m linear_classifiers.logistic_regression $override \
        encoder_config.deterministic=False \
        syllables_classifier_config.bias=False \
        syllables_classifier_config.dataset.labels=vowels \
        syllables_classifier_config.encoder_module=1 \
        syllables_classifier_config.encoder_layer=-1;
    
    echo 'vowel weight plots'; 
    python -m post_hoc_analysis.interpretability.main_vowel_classifier_analysis $override \
        encoder_config.deterministic=False \
        syllables_classifier_config.bias=False \
        syllables_classifier_config.dataset.labels=vowels \
        syllables_classifier_config.encoder_module=1 \
        syllables_classifier_config.encoder_layer=-1;
    
    python -m linear_classifiers.logistic_regression $override \
        encoder_config.deterministic=False \
        syllables_classifier_config.bias=False \
        syllables_classifier_config.dataset.labels=vowels \
        syllables_classifier_config.encoder_module=2 \
        syllables_classifier_config.encoder_layer=-1;
    
    echo 'vowel weight plots'; 
    python -m post_hoc_analysis.interpretability.main_vowel_classifier_analysis $override \
        encoder_config.deterministic=False \
        syllables_classifier_config.bias=False \
        syllables_classifier_config.dataset.labels=vowels \
        syllables_classifier_config.encoder_module=2 \
        syllables_classifier_config.encoder_layer=-1;
    
    # Train decoder with different encoder modules
    python -m decoder.train_decoder $override \
        decoder_config.decoder_loss=0 \
        decoder_config.dataset.dataset=4 \
        decoder_config.encoder_module=0 \
        decoder_config.encoder_layer=-1;
    
    python -m decoder.train_decoder $override \
        decoder_config.decoder_loss=0 \
        decoder_config.dataset.dataset=4 \
        decoder_config.encoder_module=1 \
        decoder_config.encoder_layer=-1;
    
    python -m decoder.train_decoder $override \
        decoder_config.decoder_loss=0 \
        decoder_config.dataset.dataset=4 \
        decoder_config.encoder_module=2 \
        decoder_config.encoder_layer=-1;
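
As above, the three decoder runs differ only in `decoder_config.encoder_module` and can be generated in a loop. Again a sketch: `decoder_sweep` is a hypothetical helper name, and `$override` is assumed to be set as in "Set required variables".

```shell
# Print the three decoder-training commands (encoder modules 0-2).
# `decoder_sweep` is a hypothetical helper, not part of the repository.
decoder_sweep() {
  for module in 0 1 2; do
    echo "python -m decoder.train_decoder \$override" \
         "decoder_config.decoder_loss=0" \
         "decoder_config.dataset.dataset=4" \
         "decoder_config.encoder_module=$module" \
         "decoder_config.encoder_layer=-1"
  done
}
decoder_sweep              # inspect the generated commands
# eval "$(decoder_sweep)"  # then run them in the current shell
```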
