This guide demonstrates how to migrate your multi-worker distributed training workflow from TensorFlow 1 to TensorFlow 2.
To perform multi-worker training with CPUs/GPUs:
- In TensorFlow 1, you traditionally use the `tf.estimator.train_and_evaluate` and `tf.estimator.Estimator` APIs.
- In TensorFlow 2, use the Keras APIs for writing the model, the loss function, the optimizer, and metrics. Then, distribute the training with the Keras `Model.fit` API or a custom training loop (with `tf.GradientTape`) across multiple workers with `tf.distribute.experimental.ParameterServerStrategy` or `tf.distribute.MultiWorkerMirroredStrategy`.

For more details, refer to the tutorials listed in the Next steps section at the end of this guide.
Setup
Start with some necessary imports and a simple dataset for demonstration purposes:
# The notebook uses a dataset instance for `Model.fit` with
# `ParameterServerStrategy`, which depends on symbols in TF 2.7.
# Install a utility needed for this demonstration
!pip install portpicker
import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]
You will need the `'TF_CONFIG'` configuration environment variable for training on multiple machines in TensorFlow. Use `'TF_CONFIG'` to specify the `'cluster'` and the `'task'` addresses. (Learn more in the Distributed training with TensorFlow guide.)
import json
import os
tf_config = {
'cluster': {
'chief': ['localhost:11111'],
'worker': ['localhost:12345', 'localhost:23456', 'localhost:21212'],
'ps': ['localhost:12121', 'localhost:13131'],
},
'task': {'type': 'chief', 'index': 0}
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)
Use the `del` statement to remove the variable (but in real-world multi-worker training in TensorFlow 1, you won't have to do this):
del os.environ['TF_CONFIG']
TensorFlow 1: Multi-worker distributed training with tf.estimator APIs
The following code snippet demonstrates the canonical workflow of multi-worker training in TF1: you will use a `tf.estimator.Estimator`, a `tf.estimator.TrainSpec`, a `tf.estimator.EvalSpec`, and the `tf.estimator.train_and_evaluate` API to distribute the training:
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
def _eval_input_fn():
return tf1.data.Dataset.from_tensor_slices(
(eval_features, eval_labels)).batch(1)
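# A `model_fn` that builds a one-unit linear model trained with
# mean squared error and Adagrad.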
def _model_fn(features, labels, mode):
logits = tf1.layers.Dense(1)(features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
train_spec = tf1.estimator.TrainSpec(input_fn=_input_fn)
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpftixggmu
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:loss = 2.9120758, step = 0
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 13.705803
INFO:tensorflow:Loss for final step: 8.212209.
({'loss': 13.705803, 'global_step': 3}, [])
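Note: since the `'TF_CONFIG'` variable was deleted above, `train_and_evaluate` ran locally, as the log shows. In a real TF1 cluster, each machine would run this same script with its own `'TF_CONFIG'` set, for example (the hostnames and ports below are placeholders):

# On the first worker machine of a TF1 cluster:
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
        'ps': ['host3:2222'],
    },
    'task': {'type': 'worker', 'index': 0}
})
# `tf1.estimator.train_and_evaluate` reads this variable and
# distributes the training accordingly.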
TensorFlow 2: Multi-worker training with distribution strategies
In TensorFlow 2, distributed training across multiple workers with CPUs, GPUs, and TPUs is done via `tf.distribute.Strategy`s.

The following example demonstrates how to use two such strategies: `tf.distribute.experimental.ParameterServerStrategy` and `tf.distribute.MultiWorkerMirroredStrategy`, both of which are designed for CPU/GPU training with multiple workers.
`ParameterServerStrategy` employs a coordinator (`'chief'`), which makes it better suited to the environment of this Colab notebook. You will use some utilities here to set up an in-process cluster, where threads simulate the parameter servers (`'ps'`) and workers (`'worker'`). For more information about parameter server training, refer to the Parameter server training with ParameterServerStrategy tutorial.
In this example, first define the `'TF_CONFIG'` environment variable with the cluster information, and use a `tf.distribute.cluster_resolver.TFConfigClusterResolver` to parse it. If you are using a cluster management system for your distributed training, check whether it already provides `'TF_CONFIG'` for you, in which case you don't need to set this environment variable explicitly. (Learn more in the Setting up the `'TF_CONFIG'` environment variable section of the Distributed training with TensorFlow guide.)
# Find ports that are available for the `'chief'` (the coordinator),
# `'worker'`s, and `'ps'` (parameter servers).
import portpicker
chief_port = portpicker.pick_unused_port()
worker_ports = [portpicker.pick_unused_port() for _ in range(3)]
ps_ports = [portpicker.pick_unused_port() for _ in range(2)]
# Dump the cluster information to `'TF_CONFIG'`.
tf_config = {
'cluster': {
'chief': ["localhost:%s" % chief_port],
'worker': ["localhost:%s" % port for port in worker_ports],
'ps': ["localhost:%s" % port for port in ps_ports],
},
'task': {'type': 'chief', 'index': 0}
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)
# Use a cluster resolver to bridge the information to the strategy created below.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
Then, create `tf.distribute.Server`s for the workers and parameter servers one-by-one:
# Workers need some inter-op parallelism threads to work properly.
# This is only needed for this notebook demo; real servers
# should not need it.
worker_config = tf.compat.v1.ConfigProto()
worker_config.inter_op_parallelism_threads = 4
for i in range(3):
tf.distribute.Server(
cluster_resolver.cluster_spec(),
job_name="worker",
task_index=i,
config=worker_config)
for i in range(2):
tf.distribute.Server(
cluster_resolver.cluster_spec(),
job_name="ps",
task_index=i)
In real-world distributed training, instead of starting all the `tf.distribute.Server`s on the coordinator, you will use multiple machines, and the ones designated as `"worker"`s and `"ps"` (parameter servers) will each run a `tf.distribute.Server`. Refer to the Clusters in the real world section of the Parameter server training tutorial for more details.
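For illustration, here is a minimal sketch of the script each such dedicated machine might run, assuming its own `'TF_CONFIG'` (with the matching `'task'` entry) is already set:

# Run on each dedicated 'worker' or 'ps' machine.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol=cluster_resolver.rpc_layer or "grpc")
server.join()  # Block: serve the cluster until the process is killed.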
With everything ready, create the `ParameterServerStrategy` object:
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
INFO:tensorflow:`tf.distribute.experimental.ParameterServerStrategy` is initialized with cluster_spec: ClusterSpec({'chief': ['localhost:39773'], 'ps': ['localhost:38555', 'localhost:45563'], 'worker': ['localhost:40775', 'localhost:35049', 'localhost:41143']}) INFO:tensorflow:ParameterServerStrategyV2 is now connecting to cluster with cluster_spec: ClusterSpec({'chief': ['localhost:39773'], 'ps': ['localhost:38555', 'localhost:45563'], 'worker': ['localhost:40775', 'localhost:35049', 'localhost:41143']}) INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:chief/replica:0/task:0/device:GPU:0', '/job:chief/replica:0/task:0/device:GPU:1', '/job:chief/replica:0/task:0/device:GPU:2', '/job:chief/replica:0/task:0/device:GPU:3'], variable_device = '/device:CPU:0' INFO:tensorflow:Number of GPUs on workers: 4
Once you have created a strategy object, define the model, the optimizer, and other variables, and call the Keras `Model.compile` within the `Strategy.scope` API to distribute the training. (Refer to the `Strategy.scope` API docs for more information.)

If you prefer to customize your training by, for instance, defining the forward and backward passes yourself, refer to the Training with a custom training loop section in the Parameter server training tutorial for more details (a minimal sketch also follows the `Model.fit` example below).
dataset = tf.data.Dataset.from_tensor_slices(
(features, labels)).shuffle(10).repeat().batch(64)
eval_dataset = tf.data.Dataset.from_tensor_slices(
(eval_features, eval_labels)).repeat().batch(1)
with strategy.scope():
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.legacy.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
model.fit(dataset, epochs=5, steps_per_epoch=10)
Epoch 1/5
10/10 - 4s - loss: 4.2872 - 4s/epoch - 417ms/step
Epoch 2/5
10/10 - 0s - loss: 1.4420 - 71ms/epoch - 7ms/step
Epoch 3/5
10/10 - 0s - loss: 0.6648 - 68ms/epoch - 7ms/step
Epoch 4/5
10/10 - 0s - loss: 0.3278 - 65ms/epoch - 7ms/step
Epoch 5/5
10/10 - 0s - loss: 0.1669 - 66ms/epoch - 7ms/step
<keras.callbacks.History at 0x7fcb9c046400>
model.evaluate(eval_dataset, steps=10, return_dict=True)
10/10 - 1s - loss: 0.6655 - 1s/epoch - 117ms/step
{'loss': 0.6655489206314087}
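As mentioned above, if you'd rather define the forward and backward passes yourself, you can drive the same cluster with a custom training loop. The following is a minimal sketch under this notebook's assumptions (it reuses the `strategy`, `features`, and `labels` defined earlier); the full pattern, including fault tolerance and evaluation, is covered in the Parameter server training tutorial:

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.legacy.Adagrad(learning_rate=0.05)

@tf.function
def train_step(iterator):
  def step_fn(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
      pred = model(x, training=True)
      loss = tf.reduce_mean(tf.keras.losses.mean_squared_error(y, pred))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
  return strategy.run(step_fn, args=(next(iterator),))

# Each worker gets its own copy of the dataset via the coordinator.
per_worker_dataset = coordinator.create_per_worker_dataset(
    lambda: tf.data.Dataset.from_tensor_slices(
        (features, labels)).repeat().batch(8))
per_worker_iterator = iter(per_worker_dataset)

# `schedule` dispatches steps to the workers asynchronously;
# `join` blocks until all scheduled steps have finished.
for _ in range(20):
  coordinator.schedule(train_step, args=(per_worker_iterator,))
coordinator.join()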
Partitioners (tf.distribute.experimental.partitioners)
`ParameterServerStrategy` in TensorFlow 2 supports variable partitioning and offers the same partitioners as TensorFlow 1, with less confusing names:

- `tf.compat.v1.variable_axis_size_partitioner` -> `tf.distribute.experimental.partitioners.MaxSizePartitioner`: a partitioner that keeps shards under a maximum size.
- `tf.compat.v1.min_max_variable_partitioner` -> `tf.distribute.experimental.partitioners.MinSizePartitioner`: a partitioner that allocates a minimum size per shard.
- `tf.compat.v1.fixed_size_partitioner` -> `tf.distribute.experimental.partitioners.FixedShardsPartitioner`: a partitioner that allocates a fixed number of shards.
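For example, you can pass a partitioner to the strategy through its `variable_partitioner` argument; a minimal sketch reusing the `cluster_resolver` from above (the `strategy_with_partitioner` name is just for illustration):

strategy_with_partitioner = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver,
    variable_partitioner=(
        tf.distribute.experimental.partitioners.FixedShardsPartitioner(
            num_shards=2)))
# Variables created under this strategy's scope are each split into
# 2 shards, placed across the parameter servers.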
Alternatively, you can use a `MultiWorkerMirroredStrategy` object:
# To clean up the `TF_CONFIG` used for `ParameterServerStrategy`.
del os.environ['TF_CONFIG']
strategy = tf.distribute.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled. INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO
You can replace the strategy used above with this `MultiWorkerMirroredStrategy` object and reuse the same training code.
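For instance, here is a minimal sketch of the same `Model.fit` flow under the new strategy, reusing the `features` and `labels` from earlier (in this notebook it simply runs locally, as explained below):

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  model.compile(tf.keras.optimizers.legacy.Adagrad(learning_rate=0.05), "mse")

dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).shuffle(10).repeat().batch(64)
model.fit(dataset, epochs=5, steps_per_epoch=10)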
As with the `tf.estimator` APIs, since `MultiWorkerMirroredStrategy` is a multi-client strategy, there is no easy way to run truly distributed training in this Colab notebook, so replacing the code above with this strategy ends up running things locally. The Multi-worker training with Keras `Model.fit`/a custom training loop tutorials demonstrate how to run multi-worker training with the `'TF_CONFIG'` variable set up, with two workers on a localhost in Colab. In practice, you would create multiple workers on external IP addresses/ports and use the `'TF_CONFIG'` variable to specify the cluster configuration for each worker.
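For illustration, here is a sketch of the `'TF_CONFIG'` each of two such workers might set before creating the strategy; the hostnames and ports are placeholders, not values from this notebook:

# On the first worker machine:
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['host1:12345', 'host2:12345']},
    'task': {'type': 'worker', 'index': 0}
})
# The second worker sets the same 'cluster' but 'index': 1.
# Each worker then runs the identical training script:
strategy = tf.distribute.MultiWorkerMirroredStrategy()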
Next steps
To learn more about multi-worker distributed training with tf.distribute.experimental.ParameterServerStrategy
and tf.distribute.MultiWorkerMirroredStrategy
in TensorFlow 2, consider the following resources:
- Tutorial: Parameter server training with ParameterServerStrategy and Keras Model.fit/a custom training loop
- Tutorial: Multi-worker training with MultiWorkerMirroredStrategy and Keras Model.fit
- Tutorial: Multi-worker training with MultiWorkerMirroredStrategy and a custom training loop
- Guide: Distributed training with TensorFlow
- Guide: Optimize TensorFlow GPU performance with the TensorFlow Profiler
- Guide: Use a GPU (the Using multiple GPUs section)