Skip to content
Snippets Groups Projects
Commit 195254e9 authored by Bing Gong's avatar Bing Gong
Browse files

add lpip-tensorflow

parent 7f420174
Branches
Tags
No related merge requests found
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
[submodule "PerceptualSimilarity"]
path = PerceptualSimilarity
url = https://github.com/alexlee-gk/PerceptualSimilarity.git
BSD 2-Clause License
Copyright (c) 2018, alexlee-gk
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# lpips-tensorflow
TensorFlow port of the [PyTorch](https://github.com/richzhang/PerceptualSimilarity) implementation of the [Learned Perceptual Image Patch Similarity (LPIPS)](http://richzhang.github.io/PerceptualSimilarity/) metric.
This is done by exporting the model from PyTorch to ONNX and then to TensorFlow.
## Getting started
### Installation
- Clone this repo.
```bash
git clone https://github.com/alexlee-gk/lpips-tensorflow.git
cd lpips-tensorflow
```
- Install TensorFlow and dependencies from http://tensorflow.org/
- Install other dependencies.
```bash
pip install -r requirements.txt
```
### Using the LPIPS metric
The `lpips` TensorFlow function works with individual images or batches of images.
It also works with images of any spatial dimensions (but the dimensions should be at least the size of the network's receptive field).
This example computes the LPIPS distance between batches of images.
```python
import numpy as np
import tensorflow as tf
import lpips_tf
batch_size = 32
image_shape = (batch_size, 64, 64, 3)
image0 = np.random.random(image_shape)
image1 = np.random.random(image_shape)
image0_ph = tf.placeholder(tf.float32)
image1_ph = tf.placeholder(tf.float32)
distance_t = lpips_tf.lpips(image0_ph, image1_ph, model='net-lin', net='alex')
with tf.Session() as session:
    distance = session.run(distance_t, feed_dict={image0_ph: image0, image1_ph: image1})
```
## Exporting additional models
### Export PyTorch model to TensorFlow through ONNX
- Clone the PerceptualSimilarity submodule and add it to the PYTHONPATH.
```bash
git submodule update --init --recursive
export PYTHONPATH=PerceptualSimilarity:$PYTHONPATH
```
- Install more dependencies.
```bash
pip install -r requirements-dev.txt
```
- Export the model to ONNX (`*.onnx`) and TensorFlow (`*.pb`) files in the `~/.lpips` cache directory.
```bash
python export_to_tensorflow.py --model net-lin --net alex
```
### Known issues
- The SqueezeNet model cannot be exported since ONNX cannot export one of the operators.
import argparse
import os
import onnx
import torch
import torch.onnx
from models import dist_model as dm
def main():
    """Export a PyTorch LPIPS model to ONNX and then to a TensorFlow graph.

    Command-line arguments:
        --model: 'net-lin' or 'net' (default 'net-lin').
        --net: 'squeeze', 'alex' or 'vgg' (default 'alex').
        --version: model version string (default '0.1').
        --image_height, --image_width: spatial size of the dummy images used
            to trace the network (default 64x64). The batch and spatial
            dimensions of the exported ONNX inputs are made dynamic afterwards.

    The `.onnx` and `.pb` files are written to `~/.lpips`.

    Raises:
        AssertionError: If the converted graph does not name its two inputs
            '0:0' and '1:0', or if its only output is not the output of the
            last op in the graph. Both are assumed when the pb file is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', choices=['net-lin', 'net'], default='net-lin', help='net-lin or net')
    parser.add_argument('--net', choices=['squeeze', 'alex', 'vgg'], default='alex', help='squeeze, alex, or vgg')
    parser.add_argument('--version', type=str, default='0.1')
    parser.add_argument('--image_height', type=int, default=64)
    parser.add_argument('--image_width', type=int, default=64)
    args = parser.parse_args()

    dist_model = dm.DistModel()
    dist_model.initialize(model=args.model, net=args.net, use_gpu=False, version=args.version)
    print('Model [%s] initialized' % dist_model.name())

    # Zero-filled dummies instead of torch.Tensor(...), which returns
    # uninitialized memory (possibly NaN/inf) and makes the trace run on garbage.
    # image should be RGB, normalized to [-1, 1]
    dummy_im0 = torch.zeros(1, 3, args.image_height, args.image_width)
    dummy_im1 = torch.zeros(1, 3, args.image_height, args.image_width)

    cache_dir = os.path.expanduser('~/.lpips')
    os.makedirs(cache_dir, exist_ok=True)
    onnx_fname = os.path.join(cache_dir, '%s_%s_v%s.onnx' % (args.model, args.net, args.version))

    # export model to onnx format
    torch.onnx.export(dist_model.net, (dummy_im0, dummy_im1), onnx_fname, verbose=True)

    # load and change dimensions to be dynamic
    onnx_model = onnx.load(onnx_fname)
    for dim in (0, 2, 3):
        onnx_model.graph.input[0].type.tensor_type.shape.dim[dim].dim_param = '?'
        onnx_model.graph.input[1].type.tensor_type.shape.dim[dim].dim_param = '?'

    # needs to be imported after all the pytorch stuff, otherwise this causes a segfault
    from onnx_tf.backend import prepare
    tf_rep = prepare(onnx_model, device='CPU')
    producer_version = tf_rep.graph.graph_def_versions.producer
    pb_fname = os.path.join(cache_dir, '%s_%s_v%s_%d.pb' % (args.model, args.net, args.version, producer_version))
    tf_rep.export_graph(pb_fname)

    input0_name, input1_name = [tf_rep.tensor_dict[input_name].name for input_name in tf_rep.inputs]
    (output_name,) = [tf_rep.tensor_dict[output_name].name for output_name in tf_rep.outputs]

    # The checks below are raised explicitly (not `assert`) so that they are
    # still enforced when Python runs with -O.

    # ensure these are the names of the 2 inputs, since that will be assumed when loading the pb file
    if input0_name != '0:0':
        raise AssertionError("expected first input to be named '0:0', got %r" % input0_name)
    if input1_name != '1:0':
        raise AssertionError("expected second input to be named '1:0', got %r" % input1_name)

    # ensure that the only output is the output of the last op in the graph, since that will be assumed later
    (last_output_name,) = [output.name for output in tf_rep.graph.get_operations()[-1].outputs]
    if output_name != last_output_name:
        raise AssertionError('expected output %r to be the output of the last op, got %r'
                             % (output_name, last_output_name))


if __name__ == '__main__':
    main()
import os
import sys
import tensorflow as tf
from six.moves import urllib
_URL = 'http://rail.eecs.berkeley.edu/models/lpips'
def _download(url, output_dir):
    """Downloads the `url` file into `output_dir`.

    The data is first written to a temporary `<name>.part` file and only
    renamed to its final name once the transfer has completed, so an
    interrupted download never leaves a truncated file under the final name.

    Modified from https://github.com/tensorflow/models/blob/master/research/slim/datasets/dataset_utils.py

    Args:
        url: URL of the file to fetch. The text after the last '/' is used as
            the local file name.
        output_dir: Existing directory in which the file is stored.

    Returns:
        The path of the downloaded file.

    Raises:
        Whatever `urllib.request.urlretrieve` raises (e.g. `urllib.error.HTTPError`);
        the partial file is removed before the exception propagates.
    """
    filename = url.split('/')[-1]
    filepath = os.path.join(output_dir, filename)
    partial_path = filepath + '.part'

    def _progress(count, block_size, total_size):
        if total_size > 0:
            # the last block usually overshoots the total, so clamp to 100%
            percent = min(float(count * block_size) / float(total_size) * 100.0, 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        else:
            # total size is unknown (-1) or zero: a percentage cannot be computed
            sys.stdout.write('\r>> Downloading %s' % filename)
        sys.stdout.flush()

    try:
        urllib.request.urlretrieve(url, partial_path, _progress)
        os.replace(partial_path, filepath)
    finally:
        # only still present if the transfer or the rename failed
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath
def lpips(input0, input1, model='net-lin', net='alex', version=0.1):
    """
    Learned Perceptual Image Patch Similarity (LPIPS) metric.

    The frozen graph for the requested model is looked up in `~/.lpips` and
    downloaded there if it is missing, then imported into the default graph
    with `input0` and `input1` mapped to its two inputs.

    Args:
        input0: An image tensor of shape `[..., height, width, channels]`,
            with values in [0, 1].
        input1: An image tensor of shape `[..., height, width, channels]`,
            with values in [0, 1].
        model: Model name used to select the graph file (e.g. 'net-lin').
        net: Network name used to select the graph file (e.g. 'alex').
        version: Model version used to select the graph file.

    Returns:
        The Learned Perceptual Image Patch Similarity (LPIPS) distance,
        reshaped to the leading (batch) dimensions of `input0`.

    Raises:
        FileNotFoundError: If no graph file for the requested model is cached
            locally or available for download.

    Reference:
        Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang.
        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric.
        In CVPR, 2018.
    """
    # flatten the leading dimensions
    batch_shape = tf.shape(input0)[:-3]
    input0 = tf.reshape(input0, tf.concat([[-1], tf.shape(input0)[-3:]], axis=0))
    input1 = tf.reshape(input1, tf.concat([[-1], tf.shape(input1)[-3:]], axis=0))
    # NHWC to NCHW
    input0 = tf.transpose(input0, [0, 3, 1, 2])
    input1 = tf.transpose(input1, [0, 3, 1, 2])
    # normalize to [-1, 1]
    input0 = input0 * 2.0 - 1.0
    input1 = input1 * 2.0 - 1.0

    input0_name, input1_name = '0:0', '1:0'

    default_graph = tf.get_default_graph()
    producer_version = default_graph.graph_def_versions.producer

    cache_dir = os.path.expanduser('~/.lpips')
    os.makedirs(cache_dir, exist_ok=True)
    # files to try. try a specific producer version, but fallback to the version-less version (latest).
    pb_fnames = [
        '%s_%s_v%s_%d.pb' % (model, net, version, producer_version),
        '%s_%s_v%s.pb' % (model, net, version),
    ]
    pb_path = None
    for pb_fname in pb_fnames:
        candidate = os.path.join(cache_dir, pb_fname)
        if not os.path.isfile(candidate):
            try:
                # URLs always use '/', whereas os.path.join would insert the
                # platform separator (a backslash on Windows)
                _download('%s/%s' % (_URL, pb_fname), cache_dir)
            except urllib.error.HTTPError:
                # this file name is not available from the server; try the next one
                continue
        if os.path.isfile(candidate):
            pb_path = candidate
            break
    if pb_path is None:
        raise FileNotFoundError(
            'No LPIPS graph found in %s or at %s for any of: %s'
            % (cache_dir, _URL, ', '.join(pb_fnames)))

    with open(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def,
                                input_map={input0_name: input0, input1_name: input1})
        # the exported graph's only output is the output of its last op
        distance, = default_graph.get_operations()[-1].outputs

    if distance.shape.ndims == 4:
        distance = tf.squeeze(distance, axis=[-3, -2, -1])
    # reshape the leading dimensions
    distance = tf.reshape(distance, batch_shape)
    return distance
torch>=0.4.0
torchvision>=0.2.1
onnx
onnx-tf
numpy
six
#!/usr/bin/env python
from distutils.core import setup
# Distribution metadata for the single-module `lpips_tf` package.
_SETUP_KWARGS = {
    'name': 'lpips-tf',
    'description': 'Tensorflow port for the Learned Perceptual Image Patch Similarity (LPIPS) metric',
    'author': 'Alex Lee',
    'url': 'https://github.com/alexlee-gk/lpips-tensorflow/',
    # installed as a plain module, not a package directory
    'py_modules': ['lpips_tf'],
}

setup(**_SETUP_KWARGS)
import argparse
import cv2
import numpy as np
import tensorflow as tf
import lpips_tf
def load_image(fname):
    """Load an image file as a float32 RGB array scaled by 1/255.

    Args:
        fname: Path of the image file to read.

    Returns:
        The image converted from OpenCV's BGR channel order to RGB, cast to
        `np.float32` and divided by 255.

    Raises:
        IOError: If the file is missing or cannot be decoded.
    """
    image = cv2.imread(fname)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an obscure cvtColor error
        raise IOError('Failed to read image file: %s' % fname)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image.astype(np.float32) / 255.0
def main():
    """Print the LPIPS distances between the reference and the two example images.

    Command-line arguments:
        --model: 'net-lin' or 'net' (default 'net-lin').
        --net: 'squeeze', 'alex' or 'vgg' (default 'alex').
        --version: model version string (default '0.1').

    The example images are read from `./PerceptualSimilarity/imgs/`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', choices=['net-lin', 'net'], default='net-lin', help='net-lin or net')
    parser.add_argument('--net', choices=['squeeze', 'alex', 'vgg'], default='alex', help='squeeze, alex, or vgg')
    parser.add_argument('--version', type=str, default='0.1')
    args = parser.parse_args()

    ex_ref = load_image('./PerceptualSimilarity/imgs/ex_ref.png')
    ex_p0 = load_image('./PerceptualSimilarity/imgs/ex_p0.png')
    ex_p1 = load_image('./PerceptualSimilarity/imgs/ex_p1.png')

    image0_ph = tf.placeholder(tf.float32)
    image1_ph = tf.placeholder(tf.float32)
    distance_t = lpips_tf.lpips(image0_ph, image1_ph, model=args.model, net=args.net, version=args.version)

    # the context manager closes the session (and frees its resources) even
    # if evaluating a distance raises
    with tf.Session() as session:
        lpips_fn = session.make_callable(distance_t, [image0_ph, image1_ph])
        ex_d0 = lpips_fn(ex_ref, ex_p0)
        ex_d1 = lpips_fn(ex_ref, ex_p1)

    print('Distances: (%.3f, %.3f)' % (ex_d0, ex_d1))


if __name__ == '__main__':
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment