author: liaojh1998
score: 6 / 10
-
What is the core idea?
This paper proposes a reinforcement learning-based search for good activation functions, then analyzes and evaluates the best function found, Swish, defined as \(x \cdot \sigma(\beta x)\), where \(\sigma\) is the sigmoid function. In the paper's evaluations, Swish reliably outperforms ReLU and several other activation functions across a wide variety of tasks and networks.
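For concreteness, here is a minimal NumPy sketch of the Swish definition above (the function and variable names are mine, not the paper's):

```python
import numpy as np

def swish(x, beta=1.0):
    """Swish activation: x * sigmoid(beta * x)."""
    return x * (1.0 / (1.0 + np.exp(-beta * x)))

x = np.linspace(-5.0, 5.0, 11)
print(swish(x))            # beta = 1: the default Swish / SiL
print(swish(x, beta=0.0))  # beta = 0: reduces to x / 2
```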
-
How is it realized (technically)?
-
Activation Function Search
The authors restricted the search space to functions built from primitives that take one or two scalar values and output a scalar. A candidate activation is composed from one or more repetitions of the “core unit” \(b(u_1(x_1), u_2(x_2))\), where \(u_1\), \(u_2\) are unary functions and \(b\) is a binary function.
Note that the inputs of the unary functions are restricted to the preactivation \(x\) and the outputs of other binary functions, so the composed activation always takes a single scalar as input and outputs a single scalar, just like ReLU. The full lists of primitives used in the search are below (a toy composition of the core unit is sketched in code after the list):
- Unary functions:
- Binary functions:
- Note that \(\beta\) in the above functions is a trainable parameter.
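A toy sketch of how a candidate is assembled from the core unit \(b(u_1(x_1), u_2(x_2))\) follows; the handful of primitives below is only an illustrative subset of the paper's full unary/binary lists, and all names are mine:

```python
import numpy as np

# Illustrative subset of primitives; the paper's actual lists are much longer.
UNARY = {
    "identity": lambda x: x,
    "sigmoid":  lambda x: 1.0 / (1.0 + np.exp(-x)),
    "tanh":     np.tanh,
    "relu":     lambda x: np.maximum(x, 0.0),
}
BINARY = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "max": np.maximum,
}

def core_unit(x, u1, u2, b):
    """One core unit b(u1(x), u2(x)), here with both inputs set to the preactivation x."""
    return BINARY[b](UNARY[u1](x), UNARY[u2](x))

x = np.linspace(-3.0, 3.0, 7)
# x * sigmoid(x), i.e. the beta = 1 Swish candidate, expressed as a single core unit.
print(core_unit(x, "identity", "sigmoid", "mul"))
```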
The choice of search algorithm depends on the size of the search space. For small search spaces, the entire space is exhaustively enumerated, and a list of the top-performing activation functions, ordered by validation accuracy, is maintained. For large search spaces, an RNN controller is used to “predict” the components of the activation function:
The RNN predicts a single component of the activation function at each timestep, and each prediction is fed back as the input for the next timestep, in an autoregressive fashion. The controller is trained with reinforcement learning (Proximal Policy Optimization; Schulman et al., 2017), using the validation accuracy of the resulting activation function as the reward.
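To make the controller idea concrete, here is a rough sketch of an autoregressive RNN controller trained with plain REINFORCE (the paper uses PPO) and with the expensive child-network training replaced by a stub; every name and hyperparameter here is illustrative, not from the paper:

```python
import torch
import torch.nn as nn

VOCAB_SIZE = 8   # number of choices per prediction step (illustrative)
NUM_STEPS = 5    # number of components in one candidate function (illustrative)

class Controller(nn.Module):
    """Tiny autoregressive RNN: each step predicts one component of the
    activation function, conditioned on the previously chosen component."""
    def __init__(self, hidden=64):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE + 1, hidden)  # +1 for a <start> token
        self.rnn = nn.GRUCell(hidden, hidden)
        self.head = nn.Linear(hidden, VOCAB_SIZE)

    def sample(self):
        h = torch.zeros(1, self.rnn.hidden_size)
        token = torch.tensor([VOCAB_SIZE])  # <start>
        log_probs, choices = [], []
        for _ in range(NUM_STEPS):
            h = self.rnn(self.embed(token), h)
            dist = torch.distributions.Categorical(logits=self.head(h))
            token = dist.sample()
            log_probs.append(dist.log_prob(token))
            choices.append(token.item())
        return choices, torch.stack(log_probs).sum()

def validation_accuracy(choices):
    # Stand-in for the expensive part: train a child network with the
    # sampled activation function and return its validation accuracy.
    return torch.rand(()).item()

controller = Controller()
opt = torch.optim.Adam(controller.parameters(), lr=1e-3)
for _ in range(10):
    choices, log_prob = controller.sample()
    reward = validation_accuracy(choices)
    loss = -reward * log_prob  # REINFORCE objective; the paper uses PPO instead
    opt.zero_grad()
    loss.backward()
    opt.step()
```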
-
Search Evaluations
Each candidate activation function is evaluated by training a child network, which is computationally expensive. To speed up the search, the authors used a distributed training scheme: at each update step of the search algorithm, a batch of candidate activation functions is dispatched to parallel workers, each of which trains a child network with its assigned function, and the resulting validation accuracies are aggregated (a minimal parallel-evaluation sketch follows the next paragraph).
In particular, ResNet-20 (He et al., 2016a) was used as the child network architecture. Each child network was trained for 10K steps on CIFAR-10 (Krizhevsky & Hinton, 2009), with the candidate activation function substituted for ReLU.
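A minimal sketch of the parallel-evaluation idea, assuming a hypothetical `train_child_network` helper that trains one child network and returns its validation accuracy:

```python
from concurrent.futures import ProcessPoolExecutor

def train_child_network(candidate):
    """Hypothetical stand-in: train a child network (e.g. ResNet-20 on CIFAR-10)
    with `candidate` substituted for ReLU and return its validation accuracy."""
    return 0.0  # placeholder; the real work would run on a worker machine

def evaluate_batch(candidates, workers=8):
    # Each candidate activation gets its own child network trained in parallel;
    # the resulting validation accuracies are gathered for the search update.
    with ProcessPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(train_child_network, candidates))
```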
The top-performing functions were then tested on larger networks, specifically ResNet-164 (RN) (He et al., 2016b), Wide ResNet 28-10 (WRN) (Zagoruyko & Komodakis, 2016), and DenseNet 100-12 (DN) (Huang et al., 2017), again by replacing ReLU and using the hyperparameters described in each model's original paper.
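As a sketch of what “replacing ReLU” looks like in practice, the snippet below swaps every `nn.ReLU` in a PyTorch model for a Swish module with a trainable \(\beta\); `resnet18` is just a convenient stand-in for the architectures used in the paper:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18  # stand-in; the paper uses ResNet-20/164, WRN, DenseNet

class Swish(nn.Module):
    """x * sigmoid(beta * x), with beta optionally trainable."""
    def __init__(self, beta=1.0, trainable=True):
        super().__init__()
        beta = torch.tensor(float(beta))
        self.beta = nn.Parameter(beta) if trainable else beta

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

def replace_relu(module):
    """Recursively swap every nn.ReLU submodule for Swish."""
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU):
            setattr(module, name, Swish())
        else:
            replace_relu(child)

model = resnet18()
replace_relu(model)
```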
The authors found the Swish activation to outperform all other found activation functions, so they extensively tested Swish (with both trainable \(\beta\) and fixed \(\beta = 1\)) against several well-known activation functions (particularly variants of ELU and Softplus) on challenging real-world datasets. For image classification, CIFAR-10, CIFAR-100 (Krizhevsky & Hinton, 2009), and the ImageNet 2012 classification dataset (Russakovsky et al., 2015) were used; for machine translation, the standard WMT 2014 English→German dataset was used.
-
How well does the paper perform?
Among the top functions from the search, \(x \cdot \sigma(\beta x)\) (Swish) and \(\max(x, \sigma(x))\) consistently matched or outperformed ReLU on the RN, WRN, and DN networks on CIFAR-10, with Swish consistently about 0.2% better than ReLU.
The following figure summarizes Swish’s performance against other activation functions in different models:
In general, Swish matches or outperforms the other activation functions most of the time: by roughly 0.1%-2% accuracy on CIFAR-100, by roughly 0.1%-1% top-5 accuracy on ImageNet, and by roughly 0.1-4.3 BLEU on the WMT English→German test sets.
-
What interesting variants are explored?
The authors found that complicated activation functions consistently underperformed simpler ones during the search. The best activation functions are usually the ones that pass the raw preactivation \(x\) directly into the final binary function, i.e., functions of the form \(b(x, g(x))\).
Different values of \(\beta\) allow Swish (\(x \cdot \sigma(\beta x)\)) to interpolate between different linear-unit-style activation functions. For example (a quick numerical check follows the list):
- When \(\beta = 0\), Swish becomes the scaled linear function \(f(x) = x \cdot \sigma(0) = x/2\).
- When \(\beta = 1\), Swish becomes the Sigmoid-weighted Linear Unit (SiL) (Elfwing et al., 2017), \(f(x) = x \cdot \sigma(x)\). Like ReLU, this function is bounded for negative values of \(x\) and unbounded for positive values of \(x\); unlike ReLU, however, small negative inputs yield small nonzero (negative) outputs rather than being zeroed out.
- When \(\beta \to \infty\), the sigmoid component of Swish approaches a 0-1 step function, so Swish approaches ReLU.
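A quick numerical check of the limiting cases above (minimal NumPy sketch; the helper is redefined here to keep the snippet self-contained):

```python
import numpy as np

def swish(x, beta):
    return x * (1.0 / (1.0 + np.exp(-beta * x)))

x = np.linspace(-3.0, 3.0, 9)
print(np.allclose(swish(x, 0.0), x / 2))                            # beta = 0: x / 2
print(np.allclose(swish(x, 200.0), np.maximum(x, 0.0), atol=1e-2))  # large beta: ~ReLU
```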
The authors found that the trained values of \(\beta\) in the networks were spread between 0 and 1.5, which suggests that the flexibility of the trainable \(\beta\) parameter is useful and that models took advantage of it.
At a glance, Softplus generally matched Swish in performance on the image classification tasks, while PReLU and LReLU generally matched Swish on the translation task.
TL;DR
- Swish (\(x \cdot \sigma(\beta x)\)) can replace ReLU as the activation function in a network for slightly better performance.
- Activation functions are searched with reinforcement learning, using validation accuracy as the reward.
- Functions that perform well in smaller networks will generally perform well in larger networks as well.