Understanding Depth-wise Separable Convolutions

Read on Black Box ML

This blog post is a small excerpt from my work on paper-annotations for the task of question answering. The repo contains a collection of important question-answering papers, implemented from scratch in PyTorch with detailed explanations of the concepts and components introduced in each paper. I created the illustrations in this blog post using https://www.diagrams.net/. You can find the other references below.

Depthwise Separable Convolutions

Depthwise separable convolutions serve the same purpose as normal convolutions. The difference is that they are faster, because they need far fewer multiplications. They achieve this by splitting the convolution into two parts: a depthwise convolution and a pointwise convolution.
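One way to see why the two parts together still do the job of a convolution: a depthwise pass followed by a pointwise pass computes exactly what a single standard convolution would compute if its kernel were factored as W[n][m][i][j] = pw[n][m] · dw[m][i][j] (biases ignored). The pure-Python sketch below checks this numerically on random data. The function names, the nested-list layout and the "valid" convolution (stride 1, no padding) are assumptions made for illustration, not part of the original implementation.

```python
import random


def depthwise(image, dw):
    # image[m][y][x]; dw[m][i][j]: channel m is filtered only by its own kernel dw[m]
    # ("valid" convolution: stride 1, no padding)
    k = len(dw[0])
    d_o = len(image[0]) - k + 1
    return [[[sum(image[m][y + i][x + j] * dw[m][i][j]
                  for i in range(k) for j in range(k))
              for x in range(d_o)] for y in range(d_o)]
            for m in range(len(image))]


def pointwise(fmaps, pw):
    # pw[n][m]: the n-th 1 x 1 x M kernel mixes the M channels at every position
    d = len(fmaps[0])
    return [[[sum(pw[n][m] * fmaps[m][y][x] for m in range(len(fmaps)))
              for x in range(d)] for y in range(d)]
            for n in range(len(pw))]


def standard(image, w):
    # w[n][m][i][j]: N full D_K x D_K x M kernels
    k = len(w[0][0])
    d_o = len(image[0]) - k + 1
    return [[[sum(image[m][y + i][x + j] * w[n][m][i][j]
                  for m in range(len(image)) for i in range(k) for j in range(k))
              for x in range(d_o)] for y in range(d_o)]
            for n in range(len(w))]


def rand(*shape):
    # Nested list of uniform random values with the given shape
    if len(shape) == 1:
        return [random.uniform(-1, 1) for _ in range(shape[0])]
    return [rand(*shape[1:]) for _ in range(shape[0])]


random.seed(0)
M, N, K, D_in = 3, 2, 3, 5
D_O = D_in - K + 1
image, dw, pw = rand(M, D_in, D_in), rand(M, K, K), rand(N, M)

# Kernel of the equivalent standard convolution: W[n][m][i][j] = pw[n][m] * dw[m][i][j]
w = [[[[pw[n][m] * dw[m][i][j] for j in range(K)] for i in range(K)]
      for m in range(M)] for n in range(N)]

sep = pointwise(depthwise(image, dw), pw)
std = standard(image, w)
max_diff = max(abs(sep[n][y][x] - std[n][y][x])
               for n in range(N) for y in range(D_O) for x in range(D_O))
print(max_diff < 1e-9)  # True
```

The converse does not hold in general: most full D_K × D_K × M kernels cannot be factored this way, so the separable layer behaves like a cheaper, more constrained standard convolution.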

Depthwise separable convolutions are used in place of traditional ones because they are more memory efficient and have been observed to generalize better.

Let's understand why depthwise separable convolutions are faster than traditional convolutions. Throughout, the input has $M$ channels, each kernel spans $D_K \times D_K$ spatially, and the output feature map is $D_O \times D_O$.

Let's first count the number of multiplications in a traditional convolution operation.

The number of multiplications for a single convolution operation is the number of elements inside the kernel, that is, $D_K \times D_K \times M = D_K^2 \times M$.

To get the output feature map, we slide (convolve) this kernel over the input. Given the output dimensions, we perform $D_O$ convolutions along the width and $D_O$ along the height of the input image, i.e. $D_O^2$ in total. Therefore, the number of multiplications per kernel is $D_O^2 \times D_K^2 \times M$.

These calculations are for a single kernel. In convolutional neural networks, we usually use multiple kernels, each expected to extract a distinct feature from the input. If we use $N$ such kernels, the number of multiplications becomes $N \times D_O^2 \times D_K^2 \times M$.
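The count above can be checked by running a naive convolution and incrementing a counter at every scalar multiplication. This is a minimal pure-Python sketch, assuming a "valid" convolution (stride 1, no padding, so $D_O = D_{in} - D_K + 1$) and nested lists instead of tensors; the function name standard_conv is hypothetical.

```python
def standard_conv(image, kernels):
    # image[m][y][x]: M input channels of size D_in x D_in
    # kernels[n][m][i][j]: N kernels of size D_K x D_K x M
    # "Valid" convolution (stride 1, no padding), so D_O = D_in - D_K + 1
    M = len(image)
    d_k = len(kernels[0][0])
    d_o = len(image[0]) - d_k + 1
    mults = 0
    out = []
    for kernel in kernels:                 # N kernels -> N output feature maps
        fmap = [[0.0] * d_o for _ in range(d_o)]
        for y in range(d_o):
            for x in range(d_o):           # D_O x D_O kernel positions
                acc = 0.0
                for m in range(M):         # D_K x D_K x M products per position
                    for i in range(d_k):
                        for j in range(d_k):
                            acc += image[m][y + i][x + j] * kernel[m][i][j]
                            mults += 1
                fmap[y][x] = acc
        out.append(fmap)
    return out, mults


M, D_in, D_K, N = 3, 6, 3, 4
D_O = D_in - D_K + 1
image = [[[1.0] * D_in for _ in range(D_in)] for _ in range(M)]
kernels = [[[[1.0] * D_K for _ in range(D_K)] for _ in range(M)] for _ in range(N)]
out, mults = standard_conv(image, kernels)
print(mults, N * D_O**2 * D_K**2 * M)  # 1728 1728
```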

Depthwise convolution

In depthwise convolution, we perform convolution using kernels of dimension $D_K \times D_K \times 1$. Therefore, the number of multiplications in a single convolution operation is $D_K^2 \times 1$. If the output dimension is $D_O$, then the number of multiplications per kernel is $D_K^2 \times D_O^2$. If there are $M$ input channels, we need $M$ such kernels, one for each input channel, to get all the features. For $M$ kernels, we then get $D_K^2 \times D_O^2 \times M$ multiplications.
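The same counting trick applied to the depthwise phase: each channel is convolved only with its own $D_K \times D_K \times 1$ kernel, so the output still has $M$ channels and no channel mixing happens yet. This is a sketch under the same assumptions as a plain "valid" convolution (stride 1, no padding, nested lists); depthwise_conv is a hypothetical name.

```python
def depthwise_conv(image, kernels):
    # image[m][y][x]: M input channels of size D_in x D_in
    # kernels[m][i][j]: one D_K x D_K x 1 kernel per input channel
    # "Valid" convolution (stride 1, no padding), so D_O = D_in - D_K + 1
    d_k = len(kernels[0])
    d_o = len(image[0]) - d_k + 1
    mults = 0
    out = []
    for m in range(len(image)):            # M kernels, one per channel
        fmap = [[0.0] * d_o for _ in range(d_o)]
        for y in range(d_o):
            for x in range(d_o):           # D_O x D_O kernel positions
                acc = 0.0
                for i in range(d_k):       # D_K x D_K x 1 products per position
                    for j in range(d_k):
                        acc += image[m][y + i][x + j] * kernels[m][i][j]
                        mults += 1
                fmap[y][x] = acc
        out.append(fmap)
    return out, mults


M, D_in, D_K = 3, 6, 3
D_O = D_in - D_K + 1
image = [[[1.0] * D_in for _ in range(D_in)] for _ in range(M)]
kernels = [[[1.0] * D_K for _ in range(D_K)] for _ in range(M)]
out, mults = depthwise_conv(image, kernels)
print(len(out), mults, D_K**2 * D_O**2 * M)  # 3 432 432
```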

Pointwise convolution

This part takes the output of the depthwise convolution and convolves it with $N$ kernels of size $1 \times 1 \times M$, where $N$ is the desired number of output features/channels. Similarly,
Multiplications per convolution operation $= 1 \times 1 \times M$
Multiplications per kernel $= D_O^2 \times M$
Multiplications for $N$ output features $= N \times D_O^2 \times M$

Adding up the number of multiplications from both phases, we get

$$N \cdot D_O^2 \cdot M + D_K^2 \cdot D_O^2 \cdot M = D_O^2 \cdot M \, (N + D_K^2)$$
Comparing this with traditional convolutions,
$$\frac{D_O^2 \cdot M \, (N + D_K^2)}{D_O^2 \cdot M \cdot D_K^2 \cdot N} = \frac{1}{D_K^2} + \frac{1}{N}$$
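The ratio can be checked with exact arithmetic: the $D_O^2 \cdot M$ factor cancels, leaving $\frac{1}{D_K^2} + \frac{1}{N}$ regardless of the output size or input channel count. The layer sizes below are hypothetical values picked for illustration, not taken from the post.

```python
from fractions import Fraction


def standard_mults(d_k, d_o, m, n):
    # N x D_O^2 x D_K^2 x M multiplications for a traditional convolution
    return n * d_o**2 * d_k**2 * m


def separable_mults(d_k, d_o, m, n):
    depthwise = d_k**2 * d_o**2 * m   # D_K^2 x D_O^2 x M
    pointwise = n * d_o**2 * m        # N x D_O^2 x M
    return depthwise + pointwise


# Illustrative layer sizes (hypothetical)
d_k, d_o, m, n = 5, 28, 64, 64
ratio = Fraction(separable_mults(d_k, d_o, m, n), standard_mults(d_k, d_o, m, n))
print(ratio)                                          # 89/1600
print(ratio == Fraction(1, d_k**2) + Fraction(1, n))  # True
print(round(float(ratio), 4))                         # 0.0556
```

With a $5 \times 5$ kernel and 64 output channels, the separable version needs roughly 5.6% of the multiplications of the traditional one.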

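Since $\frac{1}{D_K^2} + \frac{1}{N}$ is well below 1 for typical layers (with a $3 \times 3$ kernel and many output channels it is just over $\frac{1}{9}$), depthwise separable convolutions need far fewer multiplications than traditional ones. In PyTorch, the depthwise phase of the convolution is implemented by setting the groups argument of nn.Conv2d to in_channels. According to the documentation,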

At groups = in_channels, each input channel is convolved with its own set of filters, of size $\left\lfloor \frac{\text{out\_channels}}{\text{in\_channels}} \right\rfloor$.
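To make the groups rule concrete, the sketch below lists which input channels each output channel of a grouped convolution reads from, following the grouping described in the nn.Conv2d documentation: input and output channels are split into groups contiguous blocks, and the weight has shape (out_channels, in_channels / groups, k, k). The helpers group_connectivity and weight_shape are hypothetical pure-Python illustrations, not PyTorch functions.

```python
def group_connectivity(in_channels, out_channels, groups):
    # Hypothetical helper: map each output channel to the input channels it
    # is computed from. Both channel counts must be divisible by groups.
    assert in_channels % groups == 0 and out_channels % groups == 0
    in_per_group = in_channels // groups
    out_per_group = out_channels // groups
    mapping = {}
    for o in range(out_channels):
        g = o // out_per_group  # output channels are split into contiguous groups
        mapping[o] = list(range(g * in_per_group, (g + 1) * in_per_group))
    return mapping


def weight_shape(in_channels, out_channels, groups, k):
    # Shape of nn.Conv2d's weight for a square k x k kernel
    return (out_channels, in_channels // groups, k, k)


print(group_connectivity(4, 4, groups=1)[0])  # [0, 1, 2, 3]  (ordinary convolution)
print(group_connectivity(4, 4, groups=4))     # {0: [0], 1: [1], 2: [2], 3: [3]}
print(group_connectivity(4, 8, groups=4))     # two filters per input channel
print(weight_shape(4, 4, groups=4, k=3))      # (4, 1, 3, 3)
```

With out_channels = 2 × in_channels and groups = in_channels, each input channel gets $\lfloor 8 / 4 \rfloor = 2$ filters, which is the "own set of filters" from the documentation; the depthwise phase in this post uses out_channels = in_channels, i.e. exactly one filter per channel.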

Implementation

The following is an implementation of the layer discussed above. It is standalone and can be plugged into any application or larger model as a component.

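from torch import nn


class DepthwiseSeparableConvolution(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # Depthwise phase: groups=in_channels gives each input channel its own
        # kernel_size x kernel_size filter. padding=kernel_size // 2 keeps the
        # spatial size unchanged for odd kernel sizes.
        self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                                        kernel_size=kernel_size, groups=in_channels,
                                        padding=kernel_size // 2)
        # Pointwise phase: 1x1 convolution that mixes channels and maps
        # in_channels -> out_channels.
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        # nn.Conv2d expects channels first: x = [bs, in_channels, H_in, W_in].
        # For NLP sequences the embedding dimension must be the channel
        # dimension; a 1D variant would use nn.Conv1d on x = [bs, emb_dim, seq_len].
        x = self.pointwise_conv(self.depthwise_conv(x))
        return x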
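The memory-efficiency claim can be checked by counting weights. The sketch below counts parameters the way nn.Conv2d allocates them (a weight of shape (out_channels, in_channels / groups, k, k) plus one bias per output channel, since bias defaults to True), assuming square kernels; the helper names are hypothetical. For an actual module, sum(p.numel() for p in layer.parameters()) should give the same numbers.

```python
def conv2d_params(in_channels, out_channels, kernel_size, groups=1, bias=True):
    # Weight: out_channels x (in_channels // groups) x k x k, plus an optional bias per output channel
    weights = out_channels * (in_channels // groups) * kernel_size**2
    return weights + (out_channels if bias else 0)


def separable_params(in_channels, out_channels, kernel_size):
    # Mirrors the layer above: depthwise (groups=in_channels) then 1x1 pointwise
    depthwise = conv2d_params(in_channels, in_channels, kernel_size, groups=in_channels)
    pointwise = conv2d_params(in_channels, out_channels, 1)
    return depthwise + pointwise


print(conv2d_params(32, 64, 3))     # 18496  (standard 3x3 convolution)
print(separable_params(32, 64, 3))  # 2432   (320 depthwise + 2112 pointwise)
```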