Add options to train and export TFLite compatible models #157
GreenAppers wants to merge 3 commits into affinelayer:master
Conversation
I converted my model and now it only outputs empty black files. Any tips?
You would have to retrain your model using the code from this PR. The issue is that the original implementation uses batch normalization with batch_size=1. That is a degenerate case where, by the definitions as I understand them, batch normalization becomes instance normalization, though I can't speak to exactly how Tensorflow implements these operations. At any rate, TFLite won't accept batch_normalization with batch_size=1 and requires instance_norm instead, which means retraining. A smart conversion tool could probably keep all the weights and just change the field in the Tensorflow protobuf that identifies batch_norm to instance_norm. Here are the command lines I used to train and export a TFLite model:
Do we need to add update_ops to train_op also? From documentation:
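The relevant note from the tf.layers.batch_normalization docs is the one about the moving averages: when training, the update ops for moving_mean and moving_variance are placed in tf.GraphKeys.UPDATE_OPS and need to be added as a dependency of the train op, along these lines (assuming `optimizer` and `loss` are already defined):

```python
import tensorflow as tf  # TF 1.x

# The ops that update moving_mean/moving_variance are collected in
# UPDATE_OPS and must run with every training step, otherwise the
# training=False (inference) path sees stale statistics.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
```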
Yes, tf.layers.batch_normalization is the same as tf.contrib.layers.instance_norm when batch_size=1. Test: max abs diff: 1.1920929e-07
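A script along these lines reproduces the check (a reconstruction, not necessarily the exact test; note the two layers default to different epsilon values, so it is pinned explicitly here):

```python
import numpy as np
import tensorflow as tf  # TF 1.x

x = tf.placeholder(tf.float32, [1, 16, 16, 8])  # batch_size = 1

# With a single example, the batch statistics are exactly the
# per-instance statistics. The epsilon defaults differ (1e-3 vs 1e-6),
# so fix it in both layers for a fair comparison.
bn = tf.layers.batch_normalization(x, epsilon=1e-5, training=True)
inorm = tf.contrib.layers.instance_norm(x, epsilon=1e-5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    data = np.random.randn(1, 16, 16, 8).astype(np.float32)
    a, b = sess.run([bn, inorm], feed_dict={x: data})
    print("max abs diff:", np.abs(a - b).max())  # on the order of 1e-7
```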
Ahh-hah! Good to know, @mrgloom. Thanks. You can therefore convert the models. I don't have time now to update this PR with code to do the conversion, but a nice utility to convert a previously trained pix2pix model to a TFLite-compatible one should be possible: load the previously trained model and a new model from this PR, then transfer the weights between them.
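A hedged sketch of what that converter could look like (illustrative, not code from this PR; it assumes the convolution variables keep the same names in both graphs):

```python
import tensorflow as tf  # TF 1.x

def transfer_weights(old_checkpoint, sess):
    """Copy matching weights from a previously trained pix2pix checkpoint
    into the (already built) TFLite-compatible graph running in `sess`.

    batch_norm's moving_mean/moving_variance have no instance_norm
    counterpart and are skipped by the name/shape check; gamma/beta may
    need an explicit name mapping depending on variable scoping.
    """
    reader = tf.train.NewCheckpointReader(old_checkpoint)
    shapes = reader.get_variable_to_shape_map()
    for var in tf.global_variables():
        name = var.op.name
        if name in shapes and shapes[name] == var.shape.as_list():
            sess.run(var.assign(reader.get_tensor(name)))
```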
These are the changes needed to get pix2pix-tensorflow running on mobile.
tf.layers.batch_normalization() on TFLite requires training=False, and that the model was trained with training=True and batch_size > 1. TFLite also lacks tf.tanh(), tf.image.convert_image_dtype(), and other ops.
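The missing ops have simple stand-ins; these are sketches of the kind of workaround meant here, not necessarily the exact substitutions in this PR:

```python
import tensorflow as tf  # TF 1.x

def tanh_lite(x):
    # tanh(x) == 2*sigmoid(2x) - 1, and sigmoid (LOGISTIC) is a
    # supported TFLite op, so this identity sidesteps the missing tanh.
    return 2.0 * tf.sigmoid(2.0 * x) - 1.0

def to_float_image(image_uint8):
    # Equivalent of tf.image.convert_image_dtype(image, tf.float32)
    # for uint8 input: cast, then rescale to [0, 1].
    return tf.cast(image_uint8, tf.float32) / 255.0
```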
Updating the batch_normalization Tensorflow variables used when training=False (the moving averages, which aren't trainable variables) requires the UPDATE_OPS dependencies.
I had to re-train the model using tf.contrib.layers.instance_norm() instead of tf.layers.batch_normalization() with batch_size=1. It seems to work well. Is it the exact same thing?
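For reference, a minimal sketch of the swap (illustrative names, not this PR's actual diff):

```python
import tensorflow as tf  # TF 1.x

def batchnorm(inputs, lite=False):
    # With batch_size=1 both branches compute the same thing (see the
    # equivalence test above), but only instance_norm is accepted by TFLite.
    if lite:
        # No moving averages, hence no UPDATE_OPS bookkeeping either.
        return tf.contrib.layers.instance_norm(inputs)
    # Registers moving-average update ops in tf.GraphKeys.UPDATE_OPS.
    return tf.layers.batch_normalization(inputs, training=True)
```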