Ajuda a proteger a Grande Barreira de Corais com TensorFlow em Kaggle Junte Desafio

Transferência de estilo artístico com TensorFlow Lite

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno Veja o modelo TF Hub

Um dos mais excitantes desenvolvimentos na aprendizagem profunda para sair recentemente é transferência artística estilo , ou a capacidade de criar uma nova imagem, conhecido como um pastiche , baseado em duas imagens de entrada: um representando o estilo artístico e um representando o conteúdo.

Exemplo de transferência de estilo

Usando essa técnica, podemos gerar belas novas obras de arte em uma variedade de estilos.

Exemplo de transferência de estilo

Se você é novo no TensorFlow Lite e está trabalhando com o Android, recomendamos explorar os seguintes aplicativos de exemplo que podem ajudá-lo a começar.

Exemplo Android exemplo iOS

Se você estiver usando uma plataforma diferente do Android ou iOS, ou você já está familiarizado com os Lite APIs TensorFlow , você pode seguir este tutorial para aprender como aplicar a transferência de estilo em qualquer par de conteúdo e imagem do estilo com um pré-treinados TensorFlow Lite modelo. Você pode usar o modelo para adicionar transferência de estilo aos seus próprios aplicativos móveis.

O modelo é open-source no GitHub . Você pode retreinar o modelo com parâmetros diferentes (por exemplo, aumentar os pesos das camadas de conteúdo para fazer a imagem de saída parecer mais com a imagem de conteúdo).

Compreenda a arquitetura do modelo

Arquitetura do Modelo

Este modelo de transferência de estilo artístico consiste em dois submodelos:

  1. Estilo Prediciton Modelo: A MobilenetV2 baseado em rede neural que leva uma imagem estilo de entrada para um estilo de gargalo vector 100-dimensão.
  2. Estilo Transform Modelo: Uma rede neural que leva aplicar um vector estilo gargalo para uma imagem de conteúdo e cria uma imagem estilizada.

Se seu aplicativo só precisa oferecer suporte a um conjunto fixo de imagens de estilo, você pode calcular seus vetores de gargalo de estilo com antecedência e excluir o modelo de previsão de estilo do binário do seu aplicativo.

Configurar

Importar dependências.

import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

Faça o download do conteúdo, das imagens de estilo e dos modelos pré-treinados do TensorFlow Lite.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
466944/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
122880/108525 [=================================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
2842624/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step
294912/284398 [===============================] - 0s 0us/step

Pré-processar as entradas

  • A imagem de conteúdo e a imagem de estilo devem ser imagens RGB com valores de pixel sendo números float32 entre [0..1].
  • O tamanho da imagem do estilo deve ser (1, 256, 256, 3). Cortamos a imagem centralmente e a redimensionamos.
  • A imagem do conteúdo deve ser (1, 384, 384, 3). Recortamos a imagem centralmente e a redimensionamos.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Visualize as entradas

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

Execute a transferência de estilo com TensorFlow Lite

Predição de estilo

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Transformação de estilo

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Mistura de estilo

Podemos misturar o estilo da imagem do conteúdo na saída estilizada, o que, por sua vez, torna a saída mais parecida com a imagem do conteúdo.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \
                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Benchmarks de desempenho

Números de benchmark de desempenho são gerados com a ferramenta descrita aqui .

Nome do modelo Tamanho do modelo Dispositivo NNAPI CPU GPU
Modelo de predição de estilo (int8) 2,8 Mb Pixel 3 (Android 10) 142ms 14ms
Pixel 4 (Android 10) 5,2 ms 6,7ms
iPhone XS (iOS 12.4.1) 10,7ms
Modelo de transformação de estilo (int8) 0,2 Mb Pixel 3 (Android 10) 540ms
Pixel 4 (Android 10) 405ms
iPhone XS (iOS 12.4.1) 251ms
Modelo de predição de estilo (float16) 4,7 Mb Pixel 3 (Android 10) 86ms 28ms 9,1 ms
Pixel 4 (Android 10) 32ms 12ms 10ms
Modelo de transferência de estilo (float16) 0,4 Mb Pixel 3 (Android 10) 1095ms 545ms 42ms
Pixel 4 (Android 10) 603ms 377ms 42ms

* 4 fios usados.
** 2 threads no iPhone para o melhor desempenho.