Implementing network construction code, using U-Net as an example

Original link: https://www.insidentally.com/articles/000031/

While reading the U-Net paper recently, I came across code online that builds the network model from scratch. The code is concise and its structure is relatively complete, so I am recording what I learned here.

This article focuses on how to implement the code. The details of the U-Net paper are not covered here; discussion of the paper itself belongs elsewhere.

Links to learning resources are at the end of the article.

U-net model

First, a simple overview of the model:

[Figure: U-net model]

To build the U-net model, the main tasks are implementing the convolution layers with pooling (down-sampling) and the transposed convolutions (up-sampling), and connecting each down-sampling stage to its mirrored up-sampling stage. Readers should study the U-net diagram first and keep in mind the number of channels at each step.

Code

Following the usual problem-solving steps in industry or competitions, the work consists of acquiring the data set, building the model, training the model (choosing a loss function and an optimizer), and validating the training results. The code is explained in that order below.

Data set acquisition

Dataset URL: Carvana Image Masking Challenge | Kaggle

Baidu website link: https://pan.baidu.com/s/1bhKCyd226__fDhWbYLGPJQ Extraction code: 4t3y

Note that readers need to set aside part of the training set as a validation set themselves, according to their own needs.
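
One possible way to set aside a validation set is sketched below. It is only an illustration: the function name split_validation and its val_count and seed parameters are hypothetical, and it assumes the mask naming used in this article (xxx.jpg paired with xxx_mask.gif) and separate folders for validation images and masks.

```python
import os
import random
import shutil


def split_validation(train_dir, mask_dir, val_dir, val_mask_dir, val_count, seed=0):
    """Move val_count randomly chosen images, with their masks, to the validation folders."""
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(val_mask_dir, exist_ok=True)

    # Sort first so that the random choice is reproducible for a given seed
    images = sorted(f for f in os.listdir(train_dir) if f.endswith(".jpg"))
    chosen = random.Random(seed).sample(images, val_count)

    for name in chosen:
        # Assumed naming convention: xxx.jpg -> xxx_mask.gif
        mask_name = name.replace(".jpg", "_mask.gif")
        shutil.move(os.path.join(train_dir, name), os.path.join(val_dir, name))
        shutil.move(os.path.join(mask_dir, mask_name), os.path.join(val_mask_dir, mask_name))

    return chosen
```

Moving the files, rather than copying them, keeps the validation images out of the training folder, so the model is never validated on images it was trained on.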

Reading the data set

import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np


class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # File names of all images in the image directory
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        # Each mask shares the base name of its image: xxx.jpg -> xxx_mask.gif
        mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '_mask.gif'))
        image = np.array(Image.open(image_path).convert('RGB'))
        mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32)
        # Map white pixels (255) to 1.0 so the mask is a binary 0/1 target
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            # The transform is called with image= and mask= keyword arguments
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations['image']
            mask = augmentations['mask']

        return image, mask

For convenience in later steps, the class inherits directly from Dataset and returns each image together with its corresponding mask.

os.listdir(path) returns the names of the files and folders under the given path. In the code above, it returns the list of file names for the whole set of training images.

os.path.join() builds the storage path of each image from the directory and the file name.

Image.open().convert() converts the image to the specified mode: 'RGB' gives a three-channel color image and 'L' gives a single-channel 8-bit grayscale image. The available modes are described in the Pillow documentation for Image.convert().

mask[mask == 255.0] = 1.0 maps the white pixels (value 255) to 1.0, so the mask becomes a binary 0/1 target that matches the (0, 1) range of the sigmoid() output used later.
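
The path handling and the mask conversion can be tried in isolation. In the sketch below, the file name example_01.jpg is hypothetical, and a small hand-written array stands in for a mask loaded with Image.open(), so no image files are needed.

```python
import os

import numpy as np

image_dir = "data/train/"
mask_dir = "data/train_masks/"
file_name = "example_01.jpg"  # hypothetical file name, for illustration only

# Same path construction as in CarvanaDataset.__getitem__
image_path = os.path.join(image_dir, file_name)
mask_path = os.path.join(mask_dir, file_name.replace(".jpg", "_mask.gif"))

# Stand-in for np.array(Image.open(mask_path).convert("L"), dtype=np.float32):
# background pixels are 0 and foreground pixels are 255
mask = np.array([[0, 255, 255],
                 [0, 0, 255]], dtype=np.float32)

# Boolean indexing rewrites only the pixels equal to 255
mask[mask == 255.0] = 1.0

print(image_path)
print(mask_path)
print(mask)
```

After the conversion the mask contains only 0.0 and 1.0, which is the form a binary segmentation target needs.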

Model building

First look at how the U-net model is constructed. Before each pooling layer there are always two convolutions, which increase the number of channels of their input. So the first step is to create the class DoubleConv.

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            # 3x3 convolution with stride 1 and padding 1: the spatial size is unchanged.
            # bias=False because the BatchNorm2d that follows has its own shift term.
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
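
In DoubleConv, the positional arguments 3, 1, 1 passed to nn.Conv2d are kernel_size, stride, and padding. The sketch below uses the standard output-size formula for a convolution without dilation to show why this combination keeps the spatial size unchanged. The helper name conv_output_size is hypothetical.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Output length of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Settings used in DoubleConv: the size is preserved
padded = conv_output_size(160, kernel_size=3, stride=1, padding=1)

# The same kernel without padding loses one pixel on each side
unpadded = conv_output_size(160, kernel_size=3, stride=1, padding=0)

print(padded, unpadded)  # 160 158
```

Because each DoubleConv preserves the height and width, only the pooling and transposed convolution layers change the spatial size in this implementation.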

Next, look at the U-net model again. Its symmetry makes it elegant, and every step is processed in a regular way: convolution, pooling, convolution, pooling. This regularity is what lets us write the code without cumbersome repetition.

class UNET(nn.Module):
    def __init__(
        self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down-sampling path: one DoubleConv per feature width
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up-sampling path: a transposed convolution followed by a DoubleConv
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # The input has feature*2 channels after concatenating the skip connection
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

In the code above, the whole U-net model is split into convolutional layers (the down-sampling path), transposed convolutional layers (the up-sampling path), a pooling layer, a bottleneck layer, and a final convolutional layer.

For the down-sampling stage, use nn.ModuleList(), determine the input and output channels of each convolution, and then build the layers in a loop.

features = [64, 128, 256, 512]

self.downs = nn.ModuleList()

for feature in features:
    self.downs.append(DoubleConv(in_channels, feature))
    in_channels = feature
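
To keep the channel counts in mind, it can help to list them without building any layers. The sketch below repeats the loops of UNET.__init__ in plain Python and records the input and output channels of every block in data-flow order. The function name unet_channel_plan and the layer labels are hypothetical.

```python
def unet_channel_plan(in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
    """Return (layer name, input channels, output channels) in data-flow order."""
    plan = []

    # Down-sampling path: each DoubleConv widens the feature maps
    for feature in features:
        plan.append(("down DoubleConv", in_channels, feature))
        in_channels = feature

    plan.append(("bottleneck DoubleConv", features[-1], features[-1] * 2))

    # Up-sampling path: the transposed convolution halves the channels, and
    # concatenating the skip connection doubles them again before the DoubleConv
    for feature in reversed(features):
        plan.append(("up ConvTranspose2d", feature * 2, feature))
        plan.append(("up DoubleConv", feature * 2, feature))

    plan.append(("final Conv2d", features[0], out_channels))
    return plan


for name, c_in, c_out in unet_channel_plan():
    print(f"{name:22s} {c_in:4d} -> {c_out:4d}")
```

With the default feature widths, the channels grow from 3 to 1024 at the bottleneck and shrink back to 64 before the final 1x1 convolution reduces them to a single output channel.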

Forward propagation is implemented in the forward() method of UNET:

    def forward(self, x):
        skip_connections = []

        # Down-sampling: save each DoubleConv output before pooling
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Reverse the list so that the deepest skip connection is used first
        skip_connections = skip_connections[::-1]

        # self.ups alternates ConvTranspose2d (even index) and DoubleConv (odd index)
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Pooling floors odd sizes, so the up-sampled map can be smaller
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

In forward propagation, note that each level of the U-net has a skip connection.

skip_connections = [] creates a list that stores the output x of each down-sampling convolution, so that it can be concatenated during up-sampling.

skip_connections = skip_connections[::-1] reverses the list, because the outputs are saved in the opposite order to the one in which they are used.

concat_skip = torch.cat((skip_connection, x), dim=1) concatenates the two tensors along the channel dimension.
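
The pairing of up-sampled feature maps with skip connections, and the reason for the TF.resize() call, can be followed with plain integers. The sketch below tracks a single spatial dimension through forward(). The function name trace_sizes is hypothetical, and it relies on two facts: 2x2 max pooling with stride 2 floors odd sizes, and a transposed convolution with kernel size 2 and stride 2 doubles the size.

```python
def trace_sizes(height, levels=4):
    """Follow one spatial dimension through the U-net forward pass."""
    skips = []
    for _ in range(levels):
        skips.append(height)    # DoubleConv keeps the size; saved as a skip connection
        height = height // 2    # 2x2 max pooling with stride 2 floors odd sizes

    skips = skips[::-1]         # deepest skip connection first

    pairs = []
    for skip in skips:
        height = height * 2     # transposed convolution doubles the size
        pairs.append((height, skip))
        height = skip           # after the optional resize, x matches the skip connection
    return pairs


print(trace_sizes(160))  # [(20, 20), (40, 40), (80, 80), (160, 160)]
print(trace_sizes(161))  # [(20, 20), (40, 40), (80, 80), (160, 161)]
```

With a height of 160 every pair matches. With 161, the last up-sampled map is one pixel short of its skip connection, which is the case the shape check in forward() handles.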

Some practical operations

I think the main reason our code structure often looks messy is that we have not wrapped each function and operation into its own unit. Here are some specific examples.

def save_checkpoint(state, filename='my_checkpoint.pth.tar'):
    print('=>Saving checkpoint')
    torch.save(state, filename)

A function that saves the trained model.

See the official PyTorch documentation for torch.save().

def load_checkpoint(checkpoint, model):
    print('=>Loading checkpoint')
    model.load_state_dict(checkpoint['state_dict'])

A function that loads a saved model, so that training that was not finished last time can be resumed.
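
A round trip through the two helpers might look like the sketch below. It is an assumption about usage rather than the original training script: a tiny nn.Conv2d stands in for the UNET model, and the checkpoint is written to a temporary directory. The one fixed requirement is that the saved dictionary has a 'state_dict' key, because load_checkpoint() reads that key.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=>Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model):
    print("=>Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])


# A tiny stand-in model; the UNET from this article would be used the same way
model = nn.Conv2d(3, 1, kernel_size=1)
path = os.path.join(tempfile.mkdtemp(), "my_checkpoint.pth.tar")

# The saved dictionary must provide the "state_dict" key read by load_checkpoint()
save_checkpoint({"state_dict": model.state_dict()}, filename=path)

# A freshly initialized model has different weights until the checkpoint is loaded
restored = nn.Conv2d(3, 1, kernel_size=1)
load_checkpoint(torch.load(path, map_location="cpu"), restored)
```

Saving a dictionary rather than the bare state_dict leaves room for additional entries, such as the optimizer state, under their own keys.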

from torch.utils.data import DataLoader


def get_loader(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=1,
    pin_memory=True,
):
    train_ds = CarvanaDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform
    )

    # Shuffle the training data at every epoch
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True
    )

    val_ds = CarvanaDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform
    )

    # Keep the validation data in a fixed order
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False
    )

    return train_loader, val_loader

This is a common helper for loading data. CarvanaDataset is the custom Dataset subclass defined earlier; any other Dataset subclass could be substituted.

Parameters of DataLoader():

pin_memory (bool, optional): if True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the memory pinning section of the PyTorch torch.utils.data documentation.
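
The DataLoader arguments can be tried without downloading the data set. In the sketch below, ToyDataset is a hypothetical stand-in for CarvanaDataset that returns random tensors. Setting pin_memory from torch.cuda.is_available() is a choice made for this example, since pinned memory only helps when batches are copied to a GPU.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for CarvanaDataset: random images and binary masks."""

    def __init__(self, length=10, height=16, width=24):
        self.images = torch.rand(length, 3, height, width)
        self.masks = (torch.rand(length, height, width) > 0.5).float()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.masks[idx]


loader = DataLoader(
    ToyDataset(),
    batch_size=4,
    num_workers=0,                         # load in the main process
    pin_memory=torch.cuda.is_available(),  # pin only when a GPU is present
    shuffle=False,
)

# 10 samples with batch_size=4 give batches of 4, 4, and 2
batch_shapes = [(tuple(images.shape), tuple(masks.shape)) for images, masks in loader]
print(batch_shapes)
```

The masks come out of the loader without a channel dimension, which is why a training loop for this kind of data adds one with unsqueeze(1) before computing the loss.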

Train the model

Setting the hyperparameters:

LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKER = 2
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "data/train/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val/"
VAL_MASK_DIR = "data/val_masks/"

Training function train_fn()

from tqdm import tqdm


def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        # Add a channel dimension so that targets match the model output shape
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward (mixed precision training)
        with torch.cuda.amp.autocast():
            preds = model(data)
            loss = loss_fn(preds, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

loop = tqdm(loader) wraps the loader in tqdm, a fast, extensible Python progress bar.

loop.set_postfix() sets the extra information shown at the end of the progress bar, here the current loss.

I have not studied the use of tqdm in depth.

Note that the forward and backward propagation code above differs from the usual pattern because it uses mixed precision training; see the PyTorch automatic mixed precision documentation for details.
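
A single mixed precision training step can be sketched on its own, as below. Several things here are assumptions for illustration: a tiny nn.Conv2d stands in for UNET, random tensors stand in for a batch, nn.BCEWithLogitsLoss is assumed as the loss for binary masks, and mixed precision is switched on only when a GPU is available so that the sketch also runs on a CPU.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # assumption: enable mixed precision only on a GPU

model = nn.Conv2d(3, 1, kernel_size=1).to(device)  # stand-in for UNET
loss_fn = nn.BCEWithLogitsLoss()                   # assumed loss for binary masks
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Random stand-ins for one batch of images and binary masks
data = torch.rand(2, 3, 16, 24, device=device)
targets = (torch.rand(2, 1, 16, 24, device=device) > 0.5).float()

# Forward pass: autocast runs eligible operations in lower precision
with torch.autocast(device_type=device, enabled=use_amp):
    preds = model(data)
    loss = loss_fn(preds, targets)

# Backward pass: the loss is scaled so that small float16 gradients do not underflow
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)  # unscales the gradients, then steps unless they contain inf/NaN
scaler.update()         # adjusts the scale factor for the next iteration
```

When the scaler is created with enabled=False, scale() returns the loss unchanged and step() simply calls optimizer.step(), so the same code path works with and without mixed precision.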

Some thoughts

The video shows clearly how to build a model from scratch, how to run it, and how to implement some additional functions along the way, which is of great benefit to a novice like me.

The uploader's GitHub page also has many other projects, most of them built from scratch. Working through them step by step is a good way to prepare for taking part in Kaggle competitions later.

If you are also just getting started with deep learning, we can exchange ideas and learn together.

