TL;DR: Look at deconv networks (transposed convolution), which regenerate an image using convolution operations. Build an encoder-decoder convolutional architecture that compresses an image to a latent representation with convolutions and then decodes an image from that compressed representation. For image segmentation, a popular architecture is U-Net.
NOTE: I can't answer for PyTorch, so I will be sharing the TensorFlow equivalent. Please feel free to ignore the code, but since you are looking for the concept, I can help you with what you need to solve this.
You are trying to generate an image as the output of the network.
A series of convolution operations helps to downsample an image. Since you need a 2D matrix (grayscale image) as output, you want to upsample as well. Such a network is called a deconv network.
The first series of layers convolves over the input, 'flattening' it into a vector of channels. The next set of layers uses 2D Conv Transpose (or deconv) operations to turn the channels back into a 2D matrix (grayscale image).
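To make the upsampling step concrete, here is a minimal pure-Python sketch of a single-channel transposed convolution with stride 1 and no padding. The function name `conv2d_transpose` and the toy data are made up for illustration; a real layer also handles multiple channels, strides, biases and learned kernels.

```python
def conv2d_transpose(image, kernel):
    """Stride-1, no-padding transposed convolution of one channel (plain lists)."""
    h, w = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    # The output grows by (kernel size - 1) along each spatial dimension.
    out = [[0.0] * (w + kw - 1) for _ in range(h + kh - 1)]
    for i in range(h):
        for j in range(w):
            # Each input pixel "stamps" a scaled copy of the kernel onto the output.
            for a in range(kh):
                for b in range(kw):
                    out[i + a][j + b] += image[i][j] * kernel[a][b]
    return out


# Hypothetical toy input: a 2x2 image and a 3x3 kernel of ones.
image = [[1.0, 2.0],
         [3.0, 4.0]]
kernel = [[1.0, 1.0, 1.0],
          [1.0, 1.0, 1.0],
          [1.0, 1.0, 1.0]]
result = conv2d_transpose(image, kernel)
print(len(result), len(result[0]))  # 4 4
```

A 2x2 input becomes a 4x4 output, which is the opposite of what a 3x3 convolution without padding does to the spatial size.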
Refer to this image for reference -
Here is sample code that shows how a deconv net takes a (128,128,1) image down through convolutions and back up to a (128,128,1) image.
PyTorch implements the equivalent of the Conv2DTranspose layer as torch.nn.ConvTranspose2d.
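from tensorflow.keras import layers, Model, utils
inp = layers.Input((128, 128, 1))
# Convolution part: each 3x3 convolution (default 'valid' padding) trims 2 pixels per dimension
x = layers.Conv2D(2, (3, 3))(inp)            # -> (126, 126, 2)
x = layers.Conv2D(4, (3, 3))(x)              # -> (124, 124, 4)
x = layers.Conv2D(6, (3, 3))(x)              # -> (122, 122, 6)
# Deconvolution part: each 3x3 transposed convolution adds 2 pixels per dimension
x = layers.Conv2DTranspose(6, (3, 3))(x)     # -> (124, 124, 6)
x = layers.Conv2DTranspose(4, (3, 3))(x)     # -> (126, 126, 4)
out = layers.Conv2DTranspose(1, (3, 3))(x)   # -> (128, 128, 1)
model = Model(inp, out)
# plot_model needs pydot and graphviz installed
utils.plot_model(model, show_shapes=True, show_layer_names=False)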
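If you want to check the shapes without building the model, the spatial sizes can be traced by hand. This is a rough sketch using the standard output-size formulas, assuming stride 1 and no padding as in the sample network above; the helper names `conv_out` and `conv_transpose_out` are made up for illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output length of a convolution along one spatial dimension.
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0):
    # Output length of a transposed convolution (no extra output padding).
    return (size - 1) * stride - 2 * padding + kernel


size = 128
trace = [size]
for _ in range(3):  # three Conv2D layers with 3x3 kernels
    size = conv_out(size, 3)
    trace.append(size)
for _ in range(3):  # three Conv2DTranspose layers with 3x3 kernels
    size = conv_transpose_out(size, 3)
    trace.append(size)
print(trace)  # [128, 126, 124, 122, 124, 126, 128]
```

With larger strides the same formulas show how quickly the encoder shrinks the image and how the decoder has to mirror it to get back to the input size.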
Also, if you are looking for tried and tested architectures in this domain, check out U-Net (U-Net: Convolutional Networks for Biomedical Image Segmentation). This is an encoder-decoder (conv2d, conv2d-transpose) architecture that uses skip connections to avoid information loss and generate better image segmentation masks.
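As a toy illustration of the skip-connection idea, an encoder feature map can be joined with a decoder feature map of the same spatial size along the channel axis, so the decoder sees the fine detail the encoder had before downsampling. The sketch below uses plain nested lists and a made-up helper `concat_channels`; it is not the U-Net implementation, and in Keras you would typically reach for `layers.Concatenate` instead.

```python
def concat_channels(encoder_map, decoder_map):
    """Join two feature maps along the channel axis.

    Each map is a list of channels; each channel is a list of rows.
    """
    enc_size = (len(encoder_map[0]), len(encoder_map[0][0]))
    dec_size = (len(decoder_map[0]), len(decoder_map[0][0]))
    if enc_size != dec_size:
        # Skip connections only work when the spatial sizes line up.
        raise ValueError(f"spatial sizes differ: {enc_size} vs {dec_size}")
    return encoder_map + decoder_map


# Hypothetical toy data: 2 encoder channels and 1 decoder channel, each 2x2.
encoder_map = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
decoder_map = [[[0, 0], [0, 1]]]
merged = concat_channels(encoder_map, decoder_map)
print(len(merged))  # 3
```

The merged map has 3 channels, and the next decoder layer would convolve over all of them.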