{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "headers" }, "source": [ "Project: /overview/_project.yaml\n", "Book: /overview/_book.yaml\n", "\n", "\n", "\n", "\n", "\n", "\n", "{% comment %}\n", "The source of truth file can be found [here]: http://google3/zz\n", "{% endcomment %}" ] }, { "cell_type": "markdown", "metadata": { "id": "metadata" }, "source": [ "
TensorFlow.org で表示\n", " | \n", "Google Colab で実行\n", " | \n", "GitHub でソースを表示 | \n", "ノートブックをダウンロード | \n", "
tf.keras.layers.Layer
メソッドに適用された `tf.compat.v1.keras.utils.track_tf1_style_variables` モデリング shim を使用した移行コードの例について説明します。TF2 モデリング shim の詳細については、[モデルマッピングガイド](./model_mapping.ipynb)を参照してください。\n",
"\n",
"このガイドでは、次の目的で使用できるアプローチについて詳しく説明します。\n",
"\n",
"- 移行されたコードを使用してトレーニングモデルから得られた結果の正確性を検証する\n",
"- TensorFlow バージョン間でコードの数値的等価性を検証する"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TaYgaekzOAHf"
},
"source": [
"## セットアップ"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:59:50.988975Z",
"iopub.status.busy": "2022-12-14T21:59:50.988539Z",
"iopub.status.idle": "2022-12-14T21:59:53.501903Z",
"shell.execute_reply": "2022-12-14T21:59:53.500764Z"
},
"id": "FkHX044DzVsd"
},
"outputs": [],
"source": [
"!pip uninstall -y -q tensorflow"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:59:53.506247Z",
"iopub.status.busy": "2022-12-14T21:59:53.505997Z",
"iopub.status.idle": "2022-12-14T22:00:16.720014Z",
"shell.execute_reply": "2022-12-14T22:00:16.719089Z"
},
"id": "M1ZgieHtyzKI"
},
"outputs": [],
"source": [
"# Install tf-nightly as the DeterministicRandomTestTool is available only in\n",
"# Tensorflow 2.8\n",
"!pip install -q tf-nightly"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:16.724207Z",
"iopub.status.busy": "2022-12-14T22:00:16.723931Z",
"iopub.status.idle": "2022-12-14T22:00:18.755026Z",
"shell.execute_reply": "2022-12-14T22:00:18.753936Z"
},
"id": "ohYETq4NCX4J"
},
"outputs": [],
"source": [
"!pip install -q tf_slim"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:18.759191Z",
"iopub.status.busy": "2022-12-14T22:00:18.758635Z",
"iopub.status.idle": "2022-12-14T22:00:21.163693Z",
"shell.execute_reply": "2022-12-14T22:00:21.163007Z"
},
"id": "MFey2HxcktP6"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-12-14 22:00:19.013155: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import tensorflow.compat.v1 as v1\n",
"\n",
"import numpy as np\n",
"import tf_slim as slim\n",
"import sys\n",
"\n",
"\n",
"from contextlib import contextmanager"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:21.167396Z",
"iopub.status.busy": "2022-12-14T22:00:21.166990Z",
"iopub.status.idle": "2022-12-14T22:00:25.720200Z",
"shell.execute_reply": "2022-12-14T22:00:25.719398Z"
},
"id": "OriidSSAmRtW"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'models'...\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Enumerating objects: 3590, done.\u001b[K\r\n",
"remote: Counting objects: 0% (1/3590)\u001b[K\r",
"remote: Counting objects: 1% (36/3590)\u001b[K\r",
"remote: Counting objects: 2% (72/3590)\u001b[K\r",
"remote: Counting objects: 3% (108/3590)\u001b[K\r",
"remote: Counting objects: 4% (144/3590)\u001b[K\r",
"remote: Counting objects: 5% (180/3590)\u001b[K\r",
"remote: Counting objects: 6% (216/3590)\u001b[K\r",
"remote: Counting objects: 7% (252/3590)\u001b[K\r",
"remote: Counting objects: 8% (288/3590)\u001b[K\r",
"remote: Counting objects: 9% (324/3590)\u001b[K\r",
"remote: Counting objects: 10% (359/3590)\u001b[K\r",
"remote: Counting objects: 11% (395/3590)\u001b[K\r",
"remote: Counting objects: 12% (431/3590)\u001b[K\r",
"remote: Counting objects: 13% (467/3590)\u001b[K\r",
"remote: Counting objects: 14% (503/3590)\u001b[K\r",
"remote: Counting objects: 15% (539/3590)\u001b[K\r",
"remote: Counting objects: 16% (575/3590)\u001b[K\r",
"remote: Counting objects: 17% (611/3590)\u001b[K\r",
"remote: Counting objects: 18% (647/3590)\u001b[K\r",
"remote: Counting objects: 19% (683/3590)\u001b[K\r",
"remote: Counting objects: 20% (718/3590)\u001b[K\r",
"remote: Counting objects: 21% (754/3590)\u001b[K\r",
"remote: Counting objects: 22% (790/3590)\u001b[K\r",
"remote: Counting objects: 23% (826/3590)\u001b[K\r",
"remote: Counting objects: 24% (862/3590)\u001b[K\r",
"remote: Counting objects: 25% (898/3590)\u001b[K\r",
"remote: Counting objects: 26% (934/3590)\u001b[K\r",
"remote: Counting objects: 27% (970/3590)\u001b[K\r",
"remote: Counting objects: 28% (1006/3590)\u001b[K\r",
"remote: Counting objects: 29% (1042/3590)\u001b[K\r",
"remote: Counting objects: 30% (1077/3590)\u001b[K\r",
"remote: Counting objects: 31% (1113/3590)\u001b[K\r",
"remote: Counting objects: 32% (1149/3590)\u001b[K\r",
"remote: Counting objects: 33% (1185/3590)\u001b[K\r",
"remote: Counting objects: 34% (1221/3590)\u001b[K\r",
"remote: Counting objects: 35% (1257/3590)\u001b[K\r",
"remote: Counting objects: 36% (1293/3590)\u001b[K\r",
"remote: Counting objects: 37% (1329/3590)\u001b[K\r",
"remote: Counting objects: 38% (1365/3590)\u001b[K\r",
"remote: Counting objects: 39% (1401/3590)\u001b[K\r",
"remote: Counting objects: 40% (1436/3590)\u001b[K\r",
"remote: Counting objects: 41% (1472/3590)\u001b[K\r",
"remote: Counting objects: 42% (1508/3590)\u001b[K\r",
"remote: Counting objects: 43% (1544/3590)\u001b[K\r",
"remote: Counting objects: 44% (1580/3590)\u001b[K\r",
"remote: Counting objects: 45% (1616/3590)\u001b[K\r",
"remote: Counting objects: 46% (1652/3590)\u001b[K\r",
"remote: Counting objects: 47% (1688/3590)\u001b[K\r",
"remote: Counting objects: 48% (1724/3590)\u001b[K\r",
"remote: Counting objects: 49% (1760/3590)\u001b[K\r",
"remote: Counting objects: 50% (1795/3590)\u001b[K\r",
"remote: Counting objects: 51% (1831/3590)\u001b[K\r",
"remote: Counting objects: 52% (1867/3590)\u001b[K\r",
"remote: Counting objects: 53% (1903/3590)\u001b[K\r",
"remote: Counting objects: 54% (1939/3590)\u001b[K\r",
"remote: Counting objects: 55% (1975/3590)\u001b[K\r",
"remote: Counting objects: 56% (2011/3590)\u001b[K\r",
"remote: Counting objects: 57% (2047/3590)\u001b[K\r",
"remote: Counting objects: 58% (2083/3590)\u001b[K\r",
"remote: Counting objects: 59% (2119/3590)\u001b[K\r",
"remote: Counting objects: 60% (2154/3590)\u001b[K\r",
"remote: Counting objects: 61% (2190/3590)\u001b[K\r",
"remote: Counting objects: 62% (2226/3590)\u001b[K\r",
"remote: Counting objects: 63% (2262/3590)\u001b[K\r",
"remote: Counting objects: 64% (2298/3590)\u001b[K\r",
"remote: Counting objects: 65% (2334/3590)\u001b[K\r",
"remote: Counting objects: 66% (2370/3590)\u001b[K\r",
"remote: Counting objects: 67% (2406/3590)\u001b[K\r",
"remote: Counting objects: 68% (2442/3590)\u001b[K\r",
"remote: Counting objects: 69% (2478/3590)\u001b[K\r",
"remote: Counting objects: 70% (2513/3590)\u001b[K\r",
"remote: Counting objects: 71% (2549/3590)\u001b[K\r",
"remote: Counting objects: 72% (2585/3590)\u001b[K\r",
"remote: Counting objects: 73% (2621/3590)\u001b[K\r",
"remote: Counting objects: 74% (2657/3590)\u001b[K\r",
"remote: Counting objects: 75% (2693/3590)\u001b[K\r",
"remote: Counting objects: 76% (2729/3590)\u001b[K\r",
"remote: Counting objects: 77% (2765/3590)\u001b[K\r",
"remote: Counting objects: 78% (2801/3590)\u001b[K\r",
"remote: Counting objects: 79% (2837/3590)\u001b[K\r",
"remote: Counting objects: 80% (2872/3590)\u001b[K\r",
"remote: Counting objects: 81% (2908/3590)\u001b[K\r",
"remote: Counting objects: 82% (2944/3590)\u001b[K\r",
"remote: Counting objects: 83% (2980/3590)\u001b[K\r",
"remote: Counting objects: 84% (3016/3590)\u001b[K\r",
"remote: Counting objects: 85% (3052/3590)\u001b[K\r",
"remote: Counting objects: 86% (3088/3590)\u001b[K\r",
"remote: Counting objects: 87% (3124/3590)\u001b[K\r",
"remote: Counting objects: 88% (3160/3590)\u001b[K\r",
"remote: Counting objects: 89% (3196/3590)\u001b[K\r",
"remote: Counting objects: 90% (3231/3590)\u001b[K\r",
"remote: Counting objects: 91% (3267/3590)\u001b[K\r",
"remote: Counting objects: 92% (3303/3590)\u001b[K\r",
"remote: Counting objects: 93% (3339/3590)\u001b[K\r",
"remote: Counting objects: 94% (3375/3590)\u001b[K\r",
"remote: Counting objects: 95% (3411/3590)\u001b[K\r",
"remote: Counting objects: 96% (3447/3590)\u001b[K\r",
"remote: Counting objects: 97% (3483/3590)\u001b[K\r",
"remote: Counting objects: 98% (3519/3590)\u001b[K\r",
"remote: Counting objects: 99% (3555/3590)\u001b[K\r",
"remote: Counting objects: 100% (3590/3590)\u001b[K\r",
"remote: Counting objects: 100% (3590/3590), done.\u001b[K\r\n",
"remote: Compressing objects: 0% (1/3006)\u001b[K\r",
"remote: Compressing objects: 1% (31/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 2% (61/3006)\u001b[K\r",
"remote: Compressing objects: 3% (91/3006)\u001b[K\r",
"remote: Compressing objects: 4% (121/3006)\u001b[K\r",
"remote: Compressing objects: 5% (151/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 6% (181/3006)\u001b[K\r",
"remote: Compressing objects: 7% (211/3006)\u001b[K\r",
"remote: Compressing objects: 8% (241/3006)\u001b[K\r",
"remote: Compressing objects: 9% (271/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 10% (301/3006)\u001b[K\r",
"remote: Compressing objects: 11% (331/3006)\u001b[K\r",
"remote: Compressing objects: 12% (361/3006)\u001b[K\r",
"remote: Compressing objects: 13% (391/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 14% (421/3006)\u001b[K\r",
"remote: Compressing objects: 15% (451/3006)\u001b[K\r",
"remote: Compressing objects: 16% (481/3006)\u001b[K\r",
"remote: Compressing objects: 17% (512/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 18% (542/3006)\u001b[K\r",
"remote: Compressing objects: 19% (572/3006)\u001b[K\r",
"remote: Compressing objects: 20% (602/3006)\u001b[K\r",
"remote: Compressing objects: 21% (632/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 22% (662/3006)\u001b[K\r",
"remote: Compressing objects: 23% (692/3006)\u001b[K\r",
"remote: Compressing objects: 24% (722/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 25% (752/3006)\u001b[K\r",
"remote: Compressing objects: 26% (782/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 27% (812/3006)\u001b[K\r",
"remote: Compressing objects: 28% (842/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 29% (872/3006)\u001b[K\r",
"remote: Compressing objects: 30% (902/3006)\u001b[K\r",
"remote: Compressing objects: 31% (932/3006)\u001b[K\r",
"remote: Compressing objects: 32% (962/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 33% (992/3006)\u001b[K\r",
"remote: Compressing objects: 34% (1023/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 35% (1053/3006)\u001b[K\r",
"remote: Compressing objects: 36% (1083/3006)\u001b[K\r",
"remote: Compressing objects: 37% (1113/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 38% (1143/3006)\u001b[K\r",
"remote: Compressing objects: 39% (1173/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 40% (1203/3006)\u001b[K\r",
"remote: Compressing objects: 41% (1233/3006)\u001b[K\r",
"remote: Compressing objects: 42% (1263/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 43% (1293/3006)\u001b[K\r",
"remote: Compressing objects: 44% (1323/3006)\u001b[K\r",
"remote: Compressing objects: 45% (1353/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 46% (1383/3006)\u001b[K\r",
"remote: Compressing objects: 47% (1413/3006)\u001b[K\r",
"remote: Compressing objects: 48% (1443/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 49% (1473/3006)\u001b[K\r",
"remote: Compressing objects: 50% (1503/3006)\u001b[K\r",
"remote: Compressing objects: 51% (1534/3006)\u001b[K\r",
"remote: Compressing objects: 52% (1564/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 53% (1594/3006)\u001b[K\r",
"remote: Compressing objects: 54% (1624/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 55% (1654/3006)\u001b[K\r",
"remote: Compressing objects: 56% (1684/3006)\u001b[K\r",
"remote: Compressing objects: 57% (1714/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 58% (1744/3006)\u001b[K\r",
"remote: Compressing objects: 59% (1774/3006)\u001b[K\r",
"remote: Compressing objects: 60% (1804/3006)\u001b[K\r",
"remote: Compressing objects: 61% (1834/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 62% (1864/3006)\u001b[K\r",
"remote: Compressing objects: 63% (1894/3006)\u001b[K\r",
"remote: Compressing objects: 63% (1903/3006)\u001b[K\r",
"remote: Compressing objects: 64% (1924/3006)\u001b[K\r",
"remote: Compressing objects: 65% (1954/3006)\u001b[K\r",
"remote: Compressing objects: 66% (1984/3006)\u001b[K\r",
"remote: Compressing objects: 67% (2015/3006)\u001b[K\r",
"remote: Compressing objects: 68% (2045/3006)\u001b[K\r",
"remote: Compressing objects: 69% (2075/3006)\u001b[K\r",
"remote: Compressing objects: 70% (2105/3006)\u001b[K\r",
"remote: Compressing objects: 71% (2135/3006)\u001b[K\r",
"remote: Compressing objects: 72% (2165/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 73% (2195/3006)\u001b[K\r",
"remote: Compressing objects: 74% (2225/3006)\u001b[K\r",
"remote: Compressing objects: 75% (2255/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 76% (2285/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 77% (2315/3006)\u001b[K\r",
"remote: Compressing objects: 78% (2345/3006)\u001b[K\r",
"remote: Compressing objects: 79% (2375/3006)\u001b[K\r",
"remote: Compressing objects: 80% (2405/3006)\u001b[K\r",
"remote: Compressing objects: 81% (2435/3006)\u001b[K\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Compressing objects: 82% (2465/3006)\u001b[K\r",
"remote: Compressing objects: 83% (2495/3006)\u001b[K\r",
"remote: Compressing objects: 84% (2526/3006)\u001b[K\r",
"remote: Compressing objects: 85% (2556/3006)\u001b[K\r",
"remote: Compressing objects: 86% (2586/3006)\u001b[K\r",
"remote: Compressing objects: 87% (2616/3006)\u001b[K\r",
"remote: Compressing objects: 88% (2646/3006)\u001b[K\r",
"remote: Compressing objects: 89% (2676/3006)\u001b[K\r",
"remote: Compressing objects: 90% (2706/3006)\u001b[K\r",
"remote: Compressing objects: 91% (2736/3006)\u001b[K\r",
"remote: Compressing objects: 92% (2766/3006)\u001b[K\r",
"remote: Compressing objects: 93% (2796/3006)\u001b[K\r",
"remote: Compressing objects: 94% (2826/3006)\u001b[K\r",
"remote: Compressing objects: 95% (2856/3006)\u001b[K\r",
"remote: Compressing objects: 96% (2886/3006)\u001b[K\r",
"remote: Compressing objects: 97% (2916/3006)\u001b[K\r",
"remote: Compressing objects: 98% (2946/3006)\u001b[K\r",
"remote: Compressing objects: 99% (2976/3006)\u001b[K\r",
"remote: Compressing objects: 100% (3006/3006)\u001b[K\r",
"remote: Compressing objects: 100% (3006/3006), done.\u001b[K\r\n",
"Receiving objects: 0% (1/3590)\r",
"Receiving objects: 1% (36/3590)\r",
"Receiving objects: 2% (72/3590)\r",
"Receiving objects: 3% (108/3590)\r",
"Receiving objects: 4% (144/3590)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 5% (180/3590)\r",
"Receiving objects: 6% (216/3590)\r",
"Receiving objects: 7% (252/3590)\r",
"Receiving objects: 8% (288/3590)\r",
"Receiving objects: 9% (324/3590)\r",
"Receiving objects: 10% (359/3590)\r",
"Receiving objects: 11% (395/3590)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 12% (431/3590)\r",
"Receiving objects: 13% (467/3590)\r",
"Receiving objects: 14% (503/3590)\r",
"Receiving objects: 15% (539/3590)\r",
"Receiving objects: 16% (575/3590)\r",
"Receiving objects: 17% (611/3590)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 18% (647/3590)\r",
"Receiving objects: 19% (683/3590)\r",
"Receiving objects: 20% (718/3590)\r",
"Receiving objects: 21% (754/3590)\r",
"Receiving objects: 22% (790/3590)\r",
"Receiving objects: 23% (826/3590)\r",
"Receiving objects: 24% (862/3590)\r",
"Receiving objects: 25% (898/3590)\r",
"Receiving objects: 26% (934/3590)\r",
"Receiving objects: 27% (970/3590)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 28% (1006/3590)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 29% (1042/3590)\r",
"Receiving objects: 30% (1077/3590)\r",
"Receiving objects: 31% (1113/3590)\r",
"Receiving objects: 32% (1149/3590)\r",
"Receiving objects: 33% (1185/3590)\r",
"Receiving objects: 34% (1221/3590)\r",
"Receiving objects: 35% (1257/3590)\r",
"Receiving objects: 36% (1293/3590)\r",
"Receiving objects: 37% (1329/3590)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 38% (1365/3590)\r",
"Receiving objects: 39% (1401/3590)\r",
"Receiving objects: 40% (1436/3590)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 41% (1472/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 42% (1508/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 43% (1544/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 44% (1580/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 45% (1616/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 46% (1652/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 47% (1688/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 48% (1724/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 49% (1760/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 50% (1795/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 51% (1831/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 52% (1867/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 53% (1903/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 54% (1939/3590), 3.79 MiB | 7.49 MiB/s\r",
"Receiving objects: 55% (1975/3590), 3.79 MiB | 7.49 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 55% (2004/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 56% (2011/3590), 17.57 MiB | 17.47 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 57% (2047/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 58% (2083/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 59% (2119/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 60% (2154/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 61% (2190/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 62% (2226/3590), 17.57 MiB | 17.47 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 63% (2262/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 64% (2298/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 65% (2334/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 66% (2370/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 67% (2406/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 68% (2442/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 69% (2478/3590), 17.57 MiB | 17.47 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 70% (2513/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 71% (2549/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 72% (2585/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 73% (2621/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 74% (2657/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 75% (2693/3590), 17.57 MiB | 17.47 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 76% (2729/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 77% (2765/3590), 17.57 MiB | 17.47 MiB/s\r",
"Receiving objects: 78% (2801/3590), 17.57 MiB | 17.47 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 79% (2837/3590), 35.19 MiB | 23.37 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 80% (2872/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 81% (2908/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 82% (2944/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 83% (2980/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 84% (3016/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 85% (3052/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 86% (3088/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 87% (3124/3590), 35.19 MiB | 23.37 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 88% (3160/3590), 35.19 MiB | 23.37 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 89% (3196/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 90% (3231/3590), 35.19 MiB | 23.37 MiB/s\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Receiving objects: 91% (3267/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 92% (3303/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 93% (3339/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 94% (3375/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 95% (3411/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 96% (3447/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 97% (3483/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 98% (3519/3590), 35.19 MiB | 23.37 MiB/s\r",
"remote: Total 3590 (delta 942), reused 1502 (delta 530), pack-reused 0\u001b[K\r\n",
"Receiving objects: 99% (3555/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 100% (3590/3590), 35.19 MiB | 23.37 MiB/s\r",
"Receiving objects: 100% (3590/3590), 47.08 MiB | 25.56 MiB/s, done.\r\n",
"Resolving deltas: 0% (0/942)\r",
"Resolving deltas: 3% (31/942)\r",
"Resolving deltas: 4% (47/942)\r",
"Resolving deltas: 5% (48/942)\r",
"Resolving deltas: 6% (57/942)\r",
"Resolving deltas: 7% (66/942)\r",
"Resolving deltas: 8% (77/942)\r",
"Resolving deltas: 9% (86/942)\r",
"Resolving deltas: 10% (95/942)\r",
"Resolving deltas: 11% (104/942)\r",
"Resolving deltas: 12% (115/942)\r",
"Resolving deltas: 13% (126/942)\r",
"Resolving deltas: 14% (132/942)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Resolving deltas: 15% (142/942)\r",
"Resolving deltas: 16% (153/942)\r",
"Resolving deltas: 17% (161/942)\r",
"Resolving deltas: 18% (170/942)\r",
"Resolving deltas: 19% (183/942)\r",
"Resolving deltas: 20% (191/942)\r",
"Resolving deltas: 21% (199/942)\r",
"Resolving deltas: 22% (209/942)\r",
"Resolving deltas: 23% (218/942)\r",
"Resolving deltas: 24% (227/942)\r",
"Resolving deltas: 25% (238/942)\r",
"Resolving deltas: 26% (246/942)\r",
"Resolving deltas: 27% (257/942)\r",
"Resolving deltas: 28% (266/942)\r",
"Resolving deltas: 29% (274/942)\r",
"Resolving deltas: 30% (284/942)\r",
"Resolving deltas: 31% (293/942)\r",
"Resolving deltas: 32% (303/942)\r",
"Resolving deltas: 33% (313/942)\r",
"Resolving deltas: 34% (321/942)\r",
"Resolving deltas: 35% (331/942)\r",
"Resolving deltas: 36% (342/942)\r",
"Resolving deltas: 37% (349/942)\r",
"Resolving deltas: 38% (359/942)\r",
"Resolving deltas: 39% (368/942)\r",
"Resolving deltas: 40% (377/942)\r",
"Resolving deltas: 41% (387/942)\r",
"Resolving deltas: 42% (400/942)\r",
"Resolving deltas: 43% (406/942)\r",
"Resolving deltas: 45% (426/942)\r",
"Resolving deltas: 46% (440/942)\r",
"Resolving deltas: 47% (448/942)\r",
"Resolving deltas: 48% (454/942)\r",
"Resolving deltas: 49% (462/942)\r",
"Resolving deltas: 50% (471/942)\r",
"Resolving deltas: 51% (481/942)\r",
"Resolving deltas: 52% (490/942)\r",
"Resolving deltas: 53% (500/942)\r",
"Resolving deltas: 54% (510/942)\r",
"Resolving deltas: 55% (519/942)\r",
"Resolving deltas: 56% (532/942)\r",
"Resolving deltas: 57% (537/942)\r",
"Resolving deltas: 58% (549/942)\r",
"Resolving deltas: 59% (556/942)\r",
"Resolving deltas: 60% (567/942)\r",
"Resolving deltas: 61% (578/942)\r",
"Resolving deltas: 62% (585/942)\r",
"Resolving deltas: 63% (594/942)\r",
"Resolving deltas: 64% (603/942)\r",
"Resolving deltas: 65% (620/942)\r",
"Resolving deltas: 66% (623/942)\r",
"Resolving deltas: 67% (637/942)\r",
"Resolving deltas: 68% (643/942)\r",
"Resolving deltas: 69% (650/942)\r",
"Resolving deltas: 70% (667/942)\r",
"Resolving deltas: 72% (679/942)\r",
"Resolving deltas: 73% (693/942)\r",
"Resolving deltas: 74% (698/942)\r",
"Resolving deltas: 75% (707/942)\r",
"Resolving deltas: 76% (724/942)\r",
"Resolving deltas: 77% (733/942)\r",
"Resolving deltas: 78% (740/942)\r",
"Resolving deltas: 79% (746/942)\r",
"Resolving deltas: 80% (755/942)\r",
"Resolving deltas: 81% (767/942)\r",
"Resolving deltas: 82% (773/942)\r",
"Resolving deltas: 83% (783/942)\r",
"Resolving deltas: 84% (792/942)\r",
"Resolving deltas: 85% (801/942)\r",
"Resolving deltas: 86% (813/942)\r",
"Resolving deltas: 87% (824/942)\r",
"Resolving deltas: 88% (829/942)\r",
"Resolving deltas: 89% (841/942)\r",
"Resolving deltas: 90% (851/942)\r",
"Resolving deltas: 91% (861/942)\r",
"Resolving deltas: 92% (867/942)\r",
"Resolving deltas: 93% (877/942)\r",
"Resolving deltas: 94% (887/942)\r",
"Resolving deltas: 95% (895/942)\r",
"Resolving deltas: 96% (911/942)\r",
"Resolving deltas: 97% (914/942)\r",
"Resolving deltas: 98% (924/942)\r",
"Resolving deltas: 99% (935/942)\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Resolving deltas: 100% (942/942)\r",
"Resolving deltas: 100% (942/942), done.\r\n"
]
}
],
"source": [
"!git clone --depth=1 https://github.com/tensorflow/models.git\n",
"import models.research.slim.nets.inception_resnet_v2 as inception"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TRacYNxnN-nk"
},
"source": [
"重要なフォワードパスコードのチャンクを shim に入れる場合は、TF1.x と同じように動作していることを確認する必要があります。たとえば、TF-Slim Inception-Resnet-v2 モデル全体を次のように shim に入れることを検討してください。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:25.725084Z",
"iopub.status.busy": "2022-12-14T22:00:25.724386Z",
"iopub.status.idle": "2022-12-14T22:00:25.728819Z",
"shell.execute_reply": "2022-12-14T22:00:25.728266Z"
},
"id": "IijQZtxeaErg"
},
"outputs": [],
"source": [
"# TF1 Inception resnet v2 forward pass based on slim layers\n",
"def inception_resnet_v2(inputs, num_classes, is_training):\n",
" with slim.arg_scope(\n",
" inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):\n",
" return inception.inception_resnet_v2(inputs, num_classes, is_training=is_training)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:25.731923Z",
"iopub.status.busy": "2022-12-14T22:00:25.731426Z",
"iopub.status.idle": "2022-12-14T22:00:26.052072Z",
"shell.execute_reply": "2022-12-14T22:00:26.051445Z"
},
"id": "Z_-Oxg9OlSd4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmpfs/tmp/ipykernel_143457/2131234657.py:8: The name tf.keras.utils.track_tf1_style_variables is deprecated. Please use tf.compat.v1.keras.utils.track_tf1_style_variables instead.\n",
"\n"
]
}
],
"source": [
"class InceptionResnetV2(tf.keras.layers.Layer):\n",
" \"\"\"Slim InceptionResnetV2 forward pass as a Keras layer\"\"\"\n",
"\n",
" def __init__(self, num_classes, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.num_classes = num_classes\n",
"\n",
" @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
" def call(self, inputs, training=None):\n",
" is_training = training or False \n",
" \n",
" # Slim does not accept `None` as a value for is_training,\n",
" # Keras will still pass `None` to layers to construct functional models\n",
" # without forcing the layer to always be in training or in inference.\n",
" # However, `None` is generally considered to run layers in inference.\n",
" \n",
" with slim.arg_scope(\n",
" inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):\n",
" return inception.inception_resnet_v2(\n",
" inputs, self.num_classes, is_training=is_training)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqFmpktjlvh9"
},
"source": [
"ここでは、このレイヤーは実際にはそのまますぐに完全に機能します(正確な正則化損失トラッキングを備えています)。\n",
"\n",
"ただし、これは当たり前のことではありません。以下のステップに従って、実際に TF1.x と同じように動作していることを確認し、数値的に完全に等価であることを確認します。これらのステップは、フォワードパスのどの部分が TF1.x からの分岐を引き起こしているかを三角測量するのにも役立ちます(モデルの別の部分ではなく、モデルのフォワードパスで分岐が発生しているかどうかを特定します)。"
]
},
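{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the regularization loss tracking mentioned above (a sketch added for illustration, not part of the original workflow): after a shim-decorated layer has been called once, the regularization losses created inside its `call` surface on the standard Keras `losses` property, so you can confirm they are being captured at all."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: confirm the shim surfaces the Slim regularization losses on the\n",
"# standard Keras `losses` property. Builds a fresh layer for illustration.\n",
"sanity_model = InceptionResnetV2(1000)\n",
"sanity_model(tf.ones((1, 299, 299, 3)))  # first call creates the weights\n",
"print('Tracked regularization losses:', len(sanity_model.losses))"
]
},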
{
"cell_type": "markdown",
"metadata": {
"id": "mmgubd9vkevp"
},
"source": [
"## ステップ 1: 変数が 1 回だけ作成されることを確認する\n",
"\n",
"最初に、各呼び出しで変数が再利用され、毎回新しい変数が誤って作成されて使用されないようにモデルが正しく構築されていることを検証する必要があります。たとえば、モデルが新しい Keras レイヤーを作成したり、各フォワードパス呼び出しで tf.Variable
を呼び出す場合、変数のキャプチャに失敗し、毎回新しい変数を作成する可能性が高くなります。\n",
"\n",
"以下は、モデルが新しい変数を作成している場合に、そのことを検出し、モデルのどの部分がそれを行っているかをデバッグするために使用できる 2 つのコンテキストマネージャースコープです。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:26.055774Z",
"iopub.status.busy": "2022-12-14T22:00:26.055201Z",
"iopub.status.idle": "2022-12-14T22:00:26.059938Z",
"shell.execute_reply": "2022-12-14T22:00:26.059399Z"
},
"id": "VMTfTXC0zW97"
},
"outputs": [],
"source": [
"@contextmanager\n",
"def assert_no_variable_creations():\n",
" \"\"\"Assert no variables are created in this context manager scope.\"\"\"\n",
" def invalid_variable_creator(next_creator, **kwargs):\n",
" raise ValueError(\"Attempted to create a new variable instead of reusing an existing one. Args: {}\".format(kwargs))\n",
"\n",
" with tf.variable_creator_scope(invalid_variable_creator):\n",
" yield\n",
"\n",
"@contextmanager\n",
"def catch_and_raise_created_variables():\n",
" \"\"\"Raise all variables created within this context manager scope (if any).\"\"\"\n",
" created_vars = []\n",
" def variable_catcher(next_creator, **kwargs):\n",
" var = next_creator(**kwargs)\n",
" created_vars.append(var)\n",
" return var\n",
"\n",
" with tf.variable_creator_scope(variable_catcher):\n",
" yield\n",
" if created_vars:\n",
" raise ValueError(\"Created vars:\", created_vars)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WOKUtciktQqv"
},
"source": [
"スコープ内で変数を作成しようとすると、最初のスコープ(`assert_no_variable_creations()`)は、すぐにエラーを発生します。これにより、スタックトレースを調べて(対話型デバッグを使用して)、既存の変数を再利用する代わりに、変数を作成したコード行を正確に把握できます。\n",
"\n",
"2 番目のスコープ(`catch_and_raise_created_variables()`)は、変数が作成された場合、スコープの最後で例外を発生させます。この例外には、スコープで作成されたすべての変数のリストが含まれます。これは、一般的なパターンを見つけることができる場合に、モデルが作成しているすべての重みのセットが何であるかを把握するのに役立ちます。ただし、これらの変数が作成された正確なコード行を特定するにはあまり役に立ちません。\n",
"\n",
"以下の両方のスコープを使用して、shim ベースの InceptionResnetV2 レイヤーが最初の呼び出し後に新しい変数を作成せずに再利用していることを確認します。"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:26.063080Z",
"iopub.status.busy": "2022-12-14T22:00:26.062615Z",
"iopub.status.idle": "2022-12-14T22:00:35.541433Z",
"shell.execute_reply": "2022-12-14T22:00:35.540664Z"
},
"id": "O9FAGotiuLbK"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
" warnings.warn('`layer.apply` is deprecated and '\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:332: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.\n",
" warnings.warn('`tf.layers.flatten` is deprecated and '\n"
]
}
],
"source": [
"model = InceptionResnetV2(1000)\n",
"height, width = 299, 299\n",
"num_classes = 1000\n",
"\n",
"inputs = tf.ones( (1, height, width, 3))\n",
"# Create all weights on the first call\n",
"model(inputs)\n",
"\n",
"# Verify that no new weights are created in followup calls\n",
"with assert_no_variable_creations():\n",
" model(inputs)\n",
"with catch_and_raise_created_variables():\n",
" model(inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9ylT-EIhu1lK"
},
"source": [
"Below, you can see how these decorators work with a layer that incorrectly creates new weights each time instead of reusing existing ones."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:35.545831Z",
"iopub.status.busy": "2022-12-14T22:00:35.545564Z",
"iopub.status.idle": "2022-12-14T22:00:35.549983Z",
"shell.execute_reply": "2022-12-14T22:00:35.549409Z"
},
"id": "gXqhPQWWtMAw"
},
"outputs": [],
"source": [
"class BrokenScalingLayer(tf.keras.layers.Layer):\n",
" \"\"\"Scaling layer that incorrectly creates new weights each time.\"\"\"\n",
"\n",
" @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
" def call(self, inputs):\n",
" var = tf.Variable(initial_value=2.0)\n",
" bias = tf.Variable(initial_value=2.0, name='bias')\n",
" return inputs * var + bias"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:35.553198Z",
"iopub.status.busy": "2022-12-14T22:00:35.552741Z",
"iopub.status.idle": "2022-12-14T22:00:35.571330Z",
"shell.execute_reply": "2022-12-14T22:00:35.570722Z"
},
"id": "ztUKlMdGvHSq"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
"  File \"/tmpfs/tmp/ipykernel_143457/1128777590.py\", line 7, in <module>\n",
"    model(inputs)\n",
"ValueError: Attempted to create a new variable instead of reusing an existing one. Args: {...}\n"
]
}
],
"source": [
"model = BrokenScalingLayer()\n",
"inputs = tf.ones( (1, height, width, 3))\n",
"model(inputs)\n",
"\n",
"try:\n",
"  with assert_no_variable_creations():\n",
"    model(inputs)\n",
"except ValueError as err:\n",
"  import traceback\n",
"  traceback.print_exc()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can fix these errors by making sure your model creates its weights only once and then reuses them on each subsequent call. Depending on the model, common fixes include:\n",
"\n",
"1. If you accidentally end up using raw `tf.Variable` calls, first check whether the variable has already been created, then reuse the existing one.\n",
"2. If you accidentally create Keras layers or models directly in the forward pass each time (as opposed to `tf.compat.v1.layers`), first check whether they have already been created, then reuse the existing ones.\n",
"3. If your code is built on top of `tf.compat.v1.layers` and you cannot assign explicit names to every `compat.v1.layer` or wrap your `compat.v1.layer` usage inside a named `variable_scope`, the auto-generated layer names will increment with each model call. Handle this by placing a named `tf.compat.v1.variable_scope` inside your shim-decorated method that wraps all of your `tf.compat.v1.layers` usage."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V4iZLV9BnwKM"
},
"source": [
"## Step 2: Check that variable counts, names, and shapes match\n",
"\n",
"The second step is to verify that the layer running in TF2 creates the same number of weights, with the same shapes, as the corresponding code does in TF1.x.\n",
"\n",
"You can do a mix of manually checking that they match, and doing the checks programmatically in a unit test, as shown below."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:35.598932Z",
"iopub.status.busy": "2022-12-14T22:00:35.598324Z",
"iopub.status.idle": "2022-12-14T22:00:53.434358Z",
"shell.execute_reply": "2022-12-14T22:00:53.433655Z"
},
"id": "m_aqag5fpun5"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
" warnings.warn('`layer.apply` is deprecated and '\n"
]
}
],
"source": [
"# Build the forward pass inside a TF1.x graph, and \n",
"# get the counts, shapes, and names of the variables\n",
"graph = tf.Graph()\n",
"with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n",
" height, width = 299, 299\n",
" num_classes = 1000\n",
" inputs = tf.ones( (1, height, width, 3))\n",
"\n",
" out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)\n",
"\n",
" tf1_variable_names_and_shapes = {\n",
" var.name: (var.trainable, var.shape) for var in tf.compat.v1.global_variables()}\n",
" num_tf1_variables = len(tf.compat.v1.global_variables())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WT1-cm99vfNU"
},
"source": [
"Next, do the same for the shim-wrapped layer in TF2. Notice that the model is called multiple times before grabbing the weights. This is done to effectively test for variable reuse."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:53.438655Z",
"iopub.status.busy": "2022-12-14T22:00:53.438122Z",
"iopub.status.idle": "2022-12-14T22:00:56.628958Z",
"shell.execute_reply": "2022-12-14T22:00:56.628217Z"
},
"id": "S7ND-lBSqmnE"
},
"outputs": [],
"source": [
"height, width = 299, 299\n",
"num_classes = 1000\n",
"\n",
"model = InceptionResnetV2(num_classes)\n",
"# The weights will not be created until you call the model\n",
"\n",
"inputs = tf.ones( (1, height, width, 3))\n",
"# Call the model multiple times before checking the weights, to verify variables\n",
"# get reused rather than accidentally creating additional variables\n",
"out, endpoints = model(inputs, training=False)\n",
"out, endpoints = model(inputs, training=False)\n",
"\n",
"# Grab the name: shape mapping and the total number of variables separately,\n",
"# because in TF2 variables can be created with the same name\n",
"num_tf2_variables = len(model.variables)\n",
"tf2_variable_names_and_shapes = {\n",
" var.name: (var.trainable, var.shape) for var in model.variables}"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:56.632897Z",
"iopub.status.busy": "2022-12-14T22:00:56.632661Z",
"iopub.status.idle": "2022-12-14T22:00:56.636503Z",
"shell.execute_reply": "2022-12-14T22:00:56.635953Z"
},
"id": "pY2P_4wqsOYw"
},
"outputs": [],
"source": [
"# Verify that the variable counts, names, and shapes all match:\n",
"assert num_tf1_variables == num_tf2_variables\n",
"assert tf1_variable_names_and_shapes == tf2_variable_names_and_shapes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N4YKJzSVwWkc"
},
"source": [
"The shim-based InceptionResnetV2 layer passes this test. However, in the case where they don't match, you can run a diff (text or otherwise) to see where the differences are.\n",
"\n",
"This can provide a lead as to which part of the model isn't behaving as expected. With eager execution, you can use pdb, interactive debugging, and breakpoints to dig into the parts of the model that seem suspicious, and debug what is going wrong in more depth."
]
},
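If the equality assertions above fail, a small diff helper can narrow down where the mismatch is. The sketch below is illustrative and not part of the guide's original code: `diff_variable_specs` is a hypothetical helper name, and it assumes the `name: (trainable, shape)` dictionaries built in this step.

```python
def diff_variable_specs(tf1_specs, tf2_specs):
    """Summarize differences between two {name: (trainable, shape)} mappings."""
    only_tf1 = sorted(set(tf1_specs) - set(tf2_specs))
    only_tf2 = sorted(set(tf2_specs) - set(tf1_specs))
    mismatched = sorted(name for name in set(tf1_specs) & set(tf2_specs)
                        if tf1_specs[name] != tf2_specs[name])
    return only_tf1, only_tf2, mismatched

# Toy example with made-up variable specs: the bias matches in shape but
# differs in trainability, so it lands in the mismatched bucket.
tf1_specs = {'conv1/kernel:0': (True, (3, 3, 3, 32)), 'conv1/bias:0': (True, (32,))}
tf2_specs = {'conv1/kernel:0': (True, (3, 3, 3, 32)), 'conv1/bias:0': (False, (32,))}
print(diff_variable_specs(tf1_specs, tf2_specs))  # ([], [], ['conv1/bias:0'])
```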
{
"cell_type": "markdown",
"metadata": {
"id": "2gYrt-_0xpRM"
},
"source": [
"### Troubleshooting\n",
"\n",
"- Pay close attention to the names of any variables created directly by explicit `tf.Variable` calls and Keras layers/models, as their variable-name generation semantics may differ slightly between TF1.x graphs and TF2 functionality such as eager execution and `tf.function`, even if everything else is working properly. If this is the case, adjust your test to account for the slightly different naming semantics.\n",
"\n",
"- You may sometimes find that the `tf.Variable`s, `tf.keras.layers.Layer`s, or `tf.keras.Model`s created in your training loop's forward pass are missing from your TF2 variable list, even if they were captured by the variable collections in TF1.x. Fix this by assigning the variables/layers/models that your forward pass creates to instance attributes in your model. See [here](https://www.tensorflow.org/guide/keras/custom_layers_and_models) for more info."
]
},
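If the only differences are cosmetic naming ones, one option is to normalize names before comparing. The helper below is a hypothetical sketch; the normalization rules are illustrative examples, not TF's actual naming semantics.

```python
import re

def normalize_var_name(name):
    """Example normalization for comparing variable names across TF1.x and TF2.

    Drops the ':0' output suffix and any auto-incremented layer suffix such as
    '_1' before a '/'. Adjust the rules to match whatever differences your own
    test actually observes.
    """
    name = re.sub(r':\d+$', '', name)      # 'dense/kernel:0' -> 'dense/kernel'
    name = re.sub(r'_\d+(?=/)', '', name)  # 'dense_1/kernel' -> 'dense/kernel'
    return name

print(normalize_var_name('dense_1/kernel:0'))  # dense/kernel
```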
{
"cell_type": "markdown",
"metadata": {
"id": "fOQJ_hUGnzkq"
},
"source": [
"## Step 3: Reset all variables and check numerical equivalence with all randomness disabled\n",
"\n",
"The next step is to verify numerical equivalence for both the actual outputs and the regularization loss tracking when you fix the model such that no random number generation is involved (such as during inference).\n",
"\n",
"The exact way to do this may depend on your specific model, but for most models (such as this one), you can do it by:\n",
"\n",
"1. Initializing the weights to the same value with no randomness. This can be done by resetting them to a fixed value after they have been created.\n",
"2. Running the model in inference mode to avoid triggering any dropout layers, which can be sources of randomness.\n",
"\n",
"The following code demonstrates how you can compare the TF1.x and TF2 results this way."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:56.640336Z",
"iopub.status.busy": "2022-12-14T22:00:56.639711Z",
"iopub.status.idle": "2022-12-14T22:01:22.153784Z",
"shell.execute_reply": "2022-12-14T22:01:22.153114Z"
},
"id": "kL4PzD2Cxzmp"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Regularization loss: 0.001182976\n"
]
},
{
"data": {
"text/plain": [
"array([0.00299837, 0.00299837, 0.00299837, 0.00299837, 0.00299837],\n",
" dtype=float32)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph = tf.Graph()\n",
"with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n",
" height, width = 299, 299\n",
" num_classes = 1000\n",
" inputs = tf.ones( (1, height, width, 3))\n",
"\n",
" out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)\n",
"\n",
" # Rather than running the global variable initializers,\n",
" # reset all variables to a constant value\n",
" var_reset = tf.group([var.assign(tf.ones_like(var) * 0.001) for var in tf.compat.v1.global_variables()])\n",
" sess.run(var_reset)\n",
"\n",
" # Grab the outputs & regularization loss\n",
" reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)\n",
" tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))\n",
" tf1_output = sess.run(out)\n",
"\n",
"print(\"Regularization loss:\", tf1_regularization_loss)\n",
"tf1_output[0][:5]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IKkoM_x72rUa"
},
"source": [
"Get the TF2 results."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:01:22.157647Z",
"iopub.status.busy": "2022-12-14T22:01:22.157122Z",
"iopub.status.idle": "2022-12-14T22:01:26.605366Z",
"shell.execute_reply": "2022-12-14T22:01:26.604754Z"
},
"id": "kb086gJwzsNo"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Regularization loss: tf.Tensor(0.0011829757, shape=(), dtype=float32)\n"
]
}
],
"source": [
"height, width = 299, 299\n",
"num_classes = 1000\n",
"\n",
"model = InceptionResnetV2(num_classes)\n",
"\n",
"inputs = tf.ones( (1, height, width, 3))\n",
"# Call the model once to create all of the weights\n",
"out, endpoints = model(inputs, training=False)\n",
"\n",
"# Rather than relying on the random initializers, reset all of the\n",
"# variables to the same constant value used in the TF1.x graph above\n",
"for var in model.variables:\n",
"  var.assign(tf.ones_like(var) * 0.001)\n",
"\n",
"# Grab the outputs & regularization loss\n",
"tf2_output, endpoints = model(inputs, training=False)\n",
"tf2_regularization_loss = tf.math.add_n(model.losses)\n",
"\n",
"print(\"Regularization loss:\", tf2_regularization_loss)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Align random number generation and check numerical equivalence in training\n",
"\n",
"Note: In `tf.compat.v1.Session`s, if seeds are not specified, random number generation depends on how many operations are in the graph at the time the random operation is added, and on how many times the graph is run. In eager execution, stateful random number generation depends on the global seed, the operation random seed, and how many times the operation with the given random seed is run. See `tf.random.set_seed` for more info."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BQbb8Hyk5YVi"
},
"source": [
"The [`v1.keras.utils.DeterministicRandomTestTool`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/keras/utils/DeterministicRandomTestTool) class below provides a context manager `scope()` that can make stateful random operations use the same seed across both TF1 graphs/sessions and eager execution.\n",
"\n",
"The tool provides two testing modes:\n",
"\n",
"1. `constant`, which uses the same seed for every single operation no matter how many times it has been called, and\n",
"2. `num_random_ops`, which uses the number of previously-observed stateful random operations as the operation seed.\n",
"\n",
"This applies both to the stateful random operations used for creating and initializing variables, and to the stateful random operations used in computation (such as for dropout layers)."
]
},
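As a rough mental model of `num_random_ops` mode, the pure-Python analog below (illustrative only, not the actual implementation) seeds each "random op" with a counter of how many ops have been observed so far, so two traces that execute the same ops in the same order reproduce the same numbers.

```python
import random

class FakeDeterministicTool:
    """Illustrative analog of num_random_ops mode: op seed = ops seen so far."""
    def __init__(self):
        self.op_count = 0

    def random_uniform(self):
        # Seed this "op" with the running counter, then advance the counter,
        # mimicking how operation seeds get derived from trace order.
        value = random.Random(self.op_count).random()
        self.op_count += 1
        return value

trace1 = FakeDeterministicTool()
trace2 = FakeDeterministicTool()
a = [trace1.random_uniform() for _ in range(3)]
b = [trace2.random_uniform() for _ in range(3)]
print(a == b)  # True: same op order yields the same numbers
```

This also shows why trace order matters in this mode: if one trace interleaves its random ops differently, the counters (and therefore the numbers) diverge.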
{
"cell_type": "markdown",
"metadata": {
"id": "MoyZenhGHDA-"
},
"source": [
"To demonstrate how to use this tool to make stateful random number generation match between sessions and eager execution, generate three random tensors."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:01:26.622973Z",
"iopub.status.busy": "2022-12-14T22:01:26.622334Z",
"iopub.status.idle": "2022-12-14T22:01:26.661486Z",
"shell.execute_reply": "2022-12-14T22:01:26.660851Z"
},
"id": "DDFfjrbXEWED"
},
"outputs": [
{
"data": {
"text/plain": [
"(array([[2.5063772],\n",
" [2.7488918],\n",
" [1.4839486]], dtype=float32),\n",
" array([[2.5063772, 2.7488918, 1.4839486],\n",
" [1.5633398, 2.1358476, 1.3693532],\n",
" [0.3598416, 1.8287641, 2.5314465]], dtype=float32),\n",
" array([[2.5063772, 2.7488918, 1.4839486],\n",
" [1.5633398, 2.1358476, 1.3693532],\n",
" [0.3598416, 1.8287641, 2.5314465]], dtype=float32))"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"random_tool = v1.keras.utils.DeterministicRandomTestTool()\n",
"with random_tool.scope():\n",
" graph = tf.Graph()\n",
" with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n",
" a = tf.random.uniform(shape=(3,1))\n",
" a = a * 3\n",
" b = tf.random.uniform(shape=(3,3))\n",
" b = b * 3\n",
" c = tf.random.uniform(shape=(3,3))\n",
" c = c * 3\n",
" graph_a, graph_b, graph_c = sess.run([a, b, c])\n",
"\n",
"graph_a, graph_b, graph_c"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:01:26.664616Z",
"iopub.status.busy": "2022-12-14T22:01:26.664113Z",
"iopub.status.idle": "2022-12-14T22:01:26.683442Z",
"shell.execute_reply": "2022-12-14T22:01:26.682855Z"
},
"id": "o9bkdPuTFpYr"
},
"outputs": [
],
"source": [
"random_tool = v1.keras.utils.DeterministicRandomTestTool()\n",
"with random_tool.scope():\n",
"  a = tf.random.uniform(shape=(3,1))\n",
"  a = a * 3\n",
"  b = tf.random.uniform(shape=(3,3))\n",
"  b = b * 3\n",
"  c = tf.random.uniform(shape=(3,3))\n",
"  c = c * 3\n",
"\n",
"a, b, c"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With this, the eagerly-executing `InceptionResnetV2` model with the shim decorator around its `tf.keras.layers.Layer` method has been verified to numerically match the slim network running in TF1 graphs and sessions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xpOAei5vRAPa"
},
"source": [
"Note: When using `DeterministicRandomTestTool` in `num_random_ops` mode, it is recommended that you directly use and call the shim-decorated `tf.keras.layers.Layer` method when testing for numerical equivalence. Embedding it within a Keras functional model or other Keras models can produce differences in the stateful random operation tracing order that can make exact matching tricky when comparing TF1.x graphs/sessions and eager execution.\n",
"\n",
"For example, calling the `InceptionResnetV2` layer directly with `training=True` interleaves variable initialization with the dropout order, following the network creation order.\n",
"\n",
"On the other hand, first putting the `tf.keras.layers.Layer` decorator in a Keras functional model and only then calling the model with `training=True` is equivalent to initializing all variables first and then using the dropout layers. This produces a different tracing order and a different set of random numbers.\n",
"\n",
"However, the default `mode='constant'` is not sensitive to these differences in tracing order and will pass without extra work even when embedding the layer in a Keras functional model."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:02:17.775531Z",
"iopub.status.busy": "2022-12-14T22:02:17.775049Z",
"iopub.status.idle": "2022-12-14T22:02:41.250009Z",
"shell.execute_reply": "2022-12-14T22:02:41.249305Z"
},
"id": "0dSR4ZNvYNYm"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Regularization loss: 1.2239965\n"
]
}
],
"source": [
"random_tool = v1.keras.utils.DeterministicRandomTestTool()\n",
"with random_tool.scope():\n",
" graph = tf.Graph()\n",
" with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n",
" height, width = 299, 299\n",
" num_classes = 1000\n",
" inputs = tf.ones( (1, height, width, 3))\n",
"\n",
" out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)\n",
"\n",
" # Initialize the variables\n",
" sess.run(tf.compat.v1.global_variables_initializer())\n",
"\n",
" # Get the outputs & regularization losses\n",
" reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)\n",
" tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))\n",
" tf1_output = sess.run(out)\n",
"\n",
" print(\"Regularization loss:\", tf1_regularization_loss)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:02:41.253336Z",
"iopub.status.busy": "2022-12-14T22:02:41.253083Z",
"iopub.status.idle": "2022-12-14T22:02:48.864187Z",
"shell.execute_reply": "2022-12-14T22:02:48.863465Z"
},
"id": "iMPMMnPtYUY7"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n",
" warnings.warn('`layer.updates` will be removed in a future version. '\n",
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n",
" self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Regularization loss: tf.Tensor(1.2239964, shape=(), dtype=float32)\n"
]
}
],
"source": [
"height, width = 299, 299\n",
"num_classes = 1000\n",
"\n",
"random_tool = v1.keras.utils.DeterministicRandomTestTool()\n",
"with random_tool.scope():\n",
" keras_input = tf.keras.Input(shape=(height, width, 3))\n",
" layer = InceptionResnetV2(num_classes)\n",
" model = tf.keras.Model(inputs=keras_input, outputs=layer(keras_input))\n",
"\n",
" inputs = tf.ones((1, height, width, 3))\n",
" tf2_output, endpoints = model(inputs, training=True)\n",
"\n",
" # Get the regularization loss\n",
" tf2_regularization_loss = tf.math.add_n(model.losses)\n",
"\n",
"print(\"Regularization loss:\", tf2_regularization_loss)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:02:48.867563Z",
"iopub.status.busy": "2022-12-14T22:02:48.867302Z",
"iopub.status.idle": "2022-12-14T22:02:48.871796Z",
"shell.execute_reply": "2022-12-14T22:02:48.871225Z"
},
"id": "jf46KUVyYUY8"
},
"outputs": [],
"source": [
"# Verify that the regularization loss and output both match\n",
"# when using the DeterministicRandomTestTool\n",
"np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)\n",
"np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hWXHjtkiZ09V"
},
"source": [
"## Step 3b or 4b (optional): Testing with pre-existing checkpoints\n",
"\n",
"After step 3 or step 4 above, it can be useful to run your numerical equivalence tests starting from pre-existing name-based checkpoints, if you have them. This can test both that your legacy checkpoint loading works correctly and that the model itself works correctly. The [Reusing TF1.x checkpoints guide](./reuse_checkpoints.ipynb) covers how to reuse your pre-existing TF1.x checkpoints and transfer them over to TF2 checkpoints."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v6i3MFmGcxYx"
},
"source": [
"## Additional testing and troubleshooting\n",
"\n",
"As you add more numerical equivalence tests, you may also choose to add a test verifying that your gradient computation (or even your optimizer updates) match.\n",
"\n",
"Backpropagation and gradient computation are more prone to floating-point numerical instabilities than model forward passes. This means that as your equivalence tests cover more non-isolated parts of your training, you may begin to see non-trivial numerical differences between running fully eagerly and your TF1 graphs. This may be caused by TensorFlow's graph optimizations, which do things such as replacing subexpressions in a graph with fewer mathematical operations.\n",
"\n",
"To isolate whether this is likely to be the case, you can compare your TF1 code to TF2 computation happening inside a `tf.function` (which applies graph optimization passes like your TF1 graph does) rather than to a purely eager computation. Alternatively, you can use `tf.config.optimizer.set_experimental_options` to disable optimization passes such as `\"arithmetic_optimization\"` before your TF1 computation, to see whether the result ends up numerically closer to your TF2 computation results. In actual training runs, using `tf.function` with optimization passes enabled is recommended for performance reasons, but you may find it useful to disable them in your numerical equivalence unit tests.\n",
"\n",
"Similarly, you may also find that `tf.compat.v1.train` optimizers and TF2 optimizers have slightly different floating-point numerics properties, even when the mathematical formulas they represent are identical. This is less likely to be an issue in your training runs, but it may require a higher numerical tolerance in equivalence unit tests."
]
}
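A loosened-tolerance comparison might look like the sketch below (illustrative, made-up values and tolerance choices; pick tolerances based on what your own model actually requires):

```python
import numpy as np

# A tight tolerance that might be fine for forward-pass outputs...
forward_tol = dict(rtol=1e-6, atol=1e-5)
# ...and a looser one for gradient/optimizer comparisons.
grad_tol = dict(rtol=1e-4, atol=1e-4)

tf1_grads = np.array([0.10000, 0.20000, 0.30000])
tf2_grads = np.array([0.10003, 0.19997, 0.30002])  # small float drift

# Passes under the looser gradient tolerance
np.testing.assert_allclose(tf1_grads, tf2_grads, **grad_tol)

# The tight forward-pass tolerance would reject the same drift
try:
    np.testing.assert_allclose(tf1_grads, tf2_grads, **forward_tol)
except AssertionError:
    print("forward_tol is too strict for gradient comparisons")
```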
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "validate_correctness.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 0
}