PriorNetworks

Updates for this fork

  • Reproduce the experiments on the synthetic dataset in Section 5.1 of the PriorNetworks paper.
  • Add environment.yml file
  • Commands for training on image datasets
  • Rewrite evaluation and add new datasets

Experiment reproduction results


Setup

Install conda environment

conda env create -f environment.yml

Install uncertainty eval library

pip install git+https://github.com/selflein/nn_uncertainty_eval

Train on the Gaussian synthetic dataset

python prior_networks/priornet/run/train_dpn_synth.py --checkpoint_path checkpoints/dpn_synth_std4 --model_dir checkpoints/dpn_synth_std4 --lrc 10 --lrc 30 --n_epochs 1000 --lr 1e-4 --gpu 0 --std 4

The values of --std (which sets the variance of the Gaussians in the dataset) used in the paper are 1 and 4.
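To train both settings from the paper in one go, the command above can be generated for each value. The following is a minimal sketch rather than part of the repository: the helper name build_train_cmd is hypothetical, and the dpn_synth_std1 folder name is an assumption that extends the dpn_synth_std4 naming shown above. All flags are copied from the training command.

```python
import shlex

TRAIN_SCRIPT = "prior_networks/priornet/run/train_dpn_synth.py"


def build_train_cmd(std, gpu=0):
    """Return the argv list for one synthetic-data training run."""
    # Assumption: checkpoint folders follow the dpn_synth_std<N> pattern
    # of the std=4 example; only the std=4 name appears in this README.
    out_dir = f"checkpoints/dpn_synth_std{std}"
    return [
        "python", TRAIN_SCRIPT,
        "--checkpoint_path", out_dir,
        "--model_dir", out_dir,
        "--lrc", "10", "--lrc", "30",
        "--n_epochs", "1000",
        "--lr", "1e-4",
        "--gpu", str(gpu),
        "--std", str(std),
    ]


if __name__ == "__main__":
    for std in (1, 4):  # the two values used in the paper
        # Only prints the command line; nothing is launched here.
        print(shlex.join(build_train_cmd(std)))
```

The script only prints the two command lines. To launch them, the same list can be passed to subprocess.run(cmd, check=True) from the repository root.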

Alternatively, pretrained models can be found here. Place the checkpoint folder in the root of the repo.

Generate figures

Run the reproduce_gaussian_exp.ipynb notebook to generate the plots found above.

Train on image datasets

  • Set up the model

    python prior_networks/priornet/setup/setup_dpn.py out/model/wideresnet28_10 32 10 --drop_rate 0.3 --arch wide_resnet28_10
    
  • Train (Prior Network)

    python prior_networks/priornet/run/train_dpn.py ../data/ CIFAR10 CIFAR100 45 5e-4 --optimizer ADAM --normalize --rotate --augment --gamma 1. --reverse_KL False --model_dir out/model/wideresnet_28_10_forward_kl_cifar_10_cifar_100 --checkpoint_path out/model/wideresnet_28_10_forward_kl_cifar_10_cifar_100 --gpu 0 --lrc 30
    
  • Train (Reverse Prior Network)

    python prior_networks/priornet/run/train_dpn.py ../data/ CIFAR10 CIFAR100 50 1e-4 --optimizer ADAM --normalize --rotate --augment --gamma 1. --reverse_KL True --model_dir out/model/wideresnet_28_10_reverse_kl_cifar_10_cifar_100 --checkpoint_path out/model/wideresnet_28_10_reverse_kl_cifar_10_cifar_100 --gpu 0 --lrc 50
    

OOD detection on image datasets

python prior_networks/priornet/run/ood_detect.py ../data CIFAR10 --ood_dataset SVHN --ood_dataset LSUN --ood_dataset Noise --ood_dataset SVHN_unscaled out/evals_new/reverse_all --model_dir out/model/wideresnet28_10_reverse_kl_cifar_10_cifar_100_new/model/model.tar --gpu 0

About

Fork reproducing results on the synthetic dataset presented in the paper "Predictive Uncertainty Estimation via Prior Networks".
