import os
import sys
import numpy as np
import tensorflow as tf
import random
import cv2
from network2 import VGG16_Deconv
from skimage.io import imread, imshow
import math
import matplotlib.pyplot as plt
IMG_WIDTH = 224
IMG_HEIGHT = 224
IMG_CHANNELS = 3
SAVED_MODEL_PATH = 'saved_model_vgg16_test2/model.ckpt'
# Placeholders, built from the size constants above so the graph and the
# preprocessing below cannot drift apart.
# X: batch of input images laid out as (batch, height, width, channels).
X = tf.placeholder(tf.float32, [None, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
# Y_: per-pixel target map (batch, height, width).  It is never fed during this
# inference run; NOTE(review): presumably only needed because VGG16_Deconv's
# signature takes it (e.g. for a training loss) — confirm.
Y_ = tf.placeholder(tf.float32, [None, IMG_HEIGHT, IMG_WIDTH])
sess = tf.Session()
# Load neural network architecture
logits = VGG16_Deconv(X, Y_)
# Load trained weights into the network (the Saver covers the graph's variables)
saver = tf.train.Saver()
saver.restore(sess, SAVED_MODEL_PATH)
# Load testing image
FILE_NAME = 'test7.JPG'
img_path = os.path.join('demo_images', FILE_NAME)
# IMREAD_COLOR always decodes to a 3-channel 8-bit BGR array.  IMREAD_UNCHANGED
# would keep grayscale (2-D), alpha (4-channel) or 16-bit layouts, which break the
# channel swap, the 3-channel reshape and the 8-bit scaling further down.
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    sys.exit('Could not read image: ' + img_path)
img = img[..., [2, 1, 0]]  # OpenCV BGR => RGB (index list makes a copy)
# cv2.resize expects its target size as (width, height).
original_dim = (img.shape[1], img.shape[0])
print(original_dim)
imshow(img)
# Resize image: VGG16 requires an input of 224x224x3
dim = (IMG_WIDTH, IMG_HEIGHT)
resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)
# Prepare data for network: add a leading batch axis -> (1, H, W, C)
test_image = np.reshape(resized, [-1, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
test_data = {X: test_image}
# Make prediction: get the predicted mask from the neural network
mask = sess.run(logits, feed_dict=test_data)
# Drop batch/singleton axes and lay the prediction out as (H, W, 1).
mask = np.reshape(np.squeeze(mask), [IMG_HEIGHT, IMG_WIDTH, 1])
# Resize the predicted mask back to the original image size.  For a
# single-channel input cv2.resize returns a 2-D (H, W) array.
mask = cv2.resize(mask, original_dim, interpolation=cv2.INTER_AREA)
# Display the raw prediction
imshow(mask)
# Binarize: positive network output => foreground (1), otherwise 0.
# NOTE(review): thresholding at 0 assumes the network emits logits
# (0 <=> probability 0.5) — confirm against VGG16_Deconv.
mask = np.where(mask > 0, 1, 0)
# Display the binary mask
imshow(mask)
AREA_fg = int(mask.sum())  # foreground area in pixels
AREA_total = mask.size     # whole image area in pixels (not just the background)
# Percentage of the whole image covered by foreground (label kept as before).
print('fg/bg: ' + str(int(100 * AREA_fg / AREA_total)) + '%')
from PIL import Image
# Replicate the binary 2-D mask across the three colour channels -> (H, W, 3).
fg_mask_rgb = np.repeat(mask[..., np.newaxis], 3, axis=2).astype(np.float64)
# Hard complementary background mask, taken BEFORE the foreground mask is blurred.
bg_mask_rgb = np.where(fg_mask_rgb > 0, 0, 1)
KERNEL_SIZE = 30
# Feather the foreground edge so the sharp subject blends into the blurred background.
fg_mask_rgb = cv2.blur(fg_mask_rgb, (KERNEL_SIZE, KERNEL_SIZE))
# Normalize 8-bit pixel values to [0, 1]; 255 (not 225) is the 8-bit maximum,
# so imshow now receives floats in its expected range.
foreground = (img * fg_mask_rgb) / 255
background = cv2.blur((img * bg_mask_rgb) / 255, (KERNEL_SIZE, KERNEL_SIZE))
outImage = cv2.add(foreground, background)
imshow(outImage)
OUTPUT_DIR = 'blur_results_test2'
# cv2.imwrite does not create directories; it just returns False.
os.makedirs(OUTPUT_DIR, exist_ok=True)
out_path = os.path.join(OUTPUT_DIR, '_' + FILE_NAME)
# Back to 0-255 uint8 (round + clip explicitly instead of relying on imwrite's
# implicit conversion), then RGB => BGR, the channel order cv2.imwrite expects.
out_bgr = np.clip(np.rint(outImage * 255), 0, 255).astype(np.uint8)[..., [2, 1, 0]]
if not cv2.imwrite(out_path, out_bgr):
    sys.exit('Failed to write image: ' + out_path)
print('Image saved!')
imshow(img)