tf.keras.layers.TorchModuleWrapper

Torch module wrapper layer.

Inherits From: Layer, Operation

TorchModuleWrapper is a wrapper class that can turn any torch.nn.Module into a Keras layer, in particular by making its parameters trackable by Keras.

module torch.nn.Module instance. If it's a LazyModule instance, then its parameters must be initialized before passing the instance to TorchModuleWrapper (e.g. by calling it once).
name The name of the layer (string).

Example:

Here's an example of how the TorchModuleWrapper can be used with vanilla PyTorch modules.

import torch.nn as nn
import torch.nn.functional as F

import keras
from keras.src.layers import TorchModuleWrapper

class Classifier(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Wrap `torch.nn.Module`s with `TorchModuleWrapper`
        # if they contain parameters
        self.conv1 = TorchModuleWrapper(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
        )
        self.conv2 = TorchModuleWrapper(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
        )
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p=0.5)
        self.fc = TorchModuleWrapper(nn.Linear(1600, 10))

    def call(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        return F.softmax(x, dim=1)


model = Classifier()
model.build((1, 28, 28))
print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"]
)
model.fit(train_loader, epochs=5)

input Retrieves the input tensor(s) of a symbolic operation.

Only returns the tensor(s) corresponding to the first time the operation was called.

output Retrieves the output tensor(s) of a layer.

Only returns the tensor(s) corresponding to the first time the operation was called.

Methods

from_config

View source

Creates a layer from its config.

This method is the reverse of get_config, capable of instantiating the same layer from the config dictionary. It does not handle layer connectivity (handled by Network), nor weights (handled by set_weights).

Args
config A Python dictionary, typically the output of get_config.

Returns
A layer instance.

parameters

View source

symbolic_call

View source