I'm not yet sure which part of the model is responsible for the ONNX export failure, so I'm posting the model files here in the hope that they help. Please ignore the ResNet parts of the code; we use MobileNetV2 as the backbone.
Let me know if you need more information, and thanks for responding.
base_model.py
#------------------------------------------------------------------------------
# Libraries
#------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchsummary
import os, warnings, sys
from utils import add_flops_counting_methods, flops_to_string
#------------------------------------------------------------------------------
# BaseModel
#------------------------------------------------------------------------------
class BaseModel(nn.Module):
    """Common base class: network summary, weight initialization and tolerant
    loading of pretrained weights."""

    def __init__(self):
        super(BaseModel, self).__init__()

    def summary(self, input_shape, batch_size=1, device='cpu', print_flops=False):
        """Print a layer-by-layer summary using ``torchsummary``.

        Args:
            input_shape: per-sample input shape without the batch dimension,
                forwarded to ``torchsummary.summary`` as ``input_size``.
            batch_size: batch size reported by the summary.
            device: device string forwarded to ``torchsummary.summary``.
            print_flops: if True, also count FLOPs on one random float sample
                using the project's ``add_flops_counting_methods`` helper
                (this switches the model to eval mode).
        """
        print("[%s] Network summary..." % (self.__class__.__name__))
        torchsummary.summary(self, input_size=input_shape, batch_size=batch_size, device=device)
        if print_flops:
            dummy_input = torch.randn([1, *input_shape], dtype=torch.float)
            counter = add_flops_counting_methods(self)
            counter.eval().start_flops_count()
            counter(dummy_input)
            print('Flops: {}'.format(flops_to_string(counter.compute_average_flops_cost())))
            print('----------------------------------------------------------------')

    def init_weights(self):
        """Initialize all submodules in place.

        Conv2d / ConvTranspose2d weights get Kaiming-normal (fan_out, relu) init.
        BatchNorm2d / GroupNorm are set to an identity affine transform.
        Linear weights are drawn from N(0, 0.01). Biases are zeroed.
        Layers without a bias, and norm layers built with ``affine=False``,
        are skipped instead of crashing on a ``None`` tensor.
        """
        print("[%s] Initialize weights..." % (self.__class__.__name__))
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # affine=False norm layers have weight/bias set to None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def load_pretrained_model(self, pretrained):
        """Copy in every pretrained tensor whose key and shape match this model.

        Entries whose key is unknown, or whose shape differs, are reported and
        skipped. All other parameters/buffers keep their current values.

        Args:
            pretrained: a checkpoint path loadable with ``torch.load`` (either a
                raw state dict or a dict holding one under ``'state_dict'``),
                or an already-loaded state dict.

        Raises:
            TypeError: if ``pretrained`` is neither a ``str`` nor a ``dict``.
        """
        if isinstance(pretrained, str):
            print("[%s] Load pretrained model from %s" % (self.__class__.__name__, pretrained))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            pretrain_dict = torch.load(pretrained, map_location='cpu')
            if 'state_dict' in pretrain_dict:
                pretrain_dict = pretrain_dict['state_dict']
        elif isinstance(pretrained, dict):
            print("[%s] Load pretrained model" % (self.__class__.__name__))
            pretrain_dict = pretrained
        else:
            # Previously fell through and failed with an UnboundLocalError.
            raise TypeError(
                "[%s] pretrained must be a checkpoint path or a state dict, got %s"
                % (self.__class__.__name__, type(pretrained).__name__))
        state_dict = self.state_dict()
        model_dict = {}
        for k, v in pretrain_dict.items():
            if k not in state_dict:
                print("[%s]"%(self.__class__.__name__), k, "is ignored due to not matching key")
            elif state_dict[k].shape != v.shape:
                print("[%s]"%(self.__class__.__name__), k, "is ignored due to not matching shape")
            else:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
#------------------------------------------------------------------------------
# BaseBackbone
#------------------------------------------------------------------------------
class BaseBackbone(BaseModel):
    """Backbone base class with a loader for checkpoints trained on fewer input
    channels than this model uses (e.g. RGB weights into an RGB+X model)."""

    def __init__(self):
        super(BaseBackbone, self).__init__()

    def load_pretrained_model_extended(self, pretrained):
        """Load pretrained weights, allowing a larger channel dimension (dim 1).

        Each pretrained tensor is handled as follows:
        - If the shape matches, it is copied as-is.
        - If it differs only by having fewer channels in dim 1, it fills the
          leading ``v.shape[1]`` channels. The remaining channels keep this
          model's current values.
        - Otherwise it is reported and skipped.
        Keys unknown to this model are reported and skipped.

        Args:
            pretrained: a checkpoint path loadable with ``torch.load`` (either a
                raw state dict or a dict holding one under ``'state_dict'``),
                or an already-loaded state dict.

        Raises:
            TypeError: if ``pretrained`` is neither a ``str`` nor a ``dict``.
        """
        if isinstance(pretrained, str):
            print("[%s] Load pretrained model from %s" % (self.__class__.__name__, pretrained))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            pretrain_dict = torch.load(pretrained, map_location='cpu')
            if 'state_dict' in pretrain_dict:
                pretrain_dict = pretrain_dict['state_dict']
        elif isinstance(pretrained, dict):
            print("[%s] Load pretrained model" % (self.__class__.__name__))
            pretrain_dict = pretrained
        else:
            raise TypeError(
                "[%s] pretrained must be a checkpoint path or a state dict, got %s"
                % (self.__class__.__name__, type(pretrained).__name__))
        state_dict = self.state_dict()
        model_dict = {}
        for k, v in pretrain_dict.items():
            if k not in state_dict:
                print("[%s]"%(self.__class__.__name__), k, "is ignored")
                continue
            target = state_dict[k]
            if target.shape == v.shape:
                model_dict[k] = v
            elif self._fits_leading_channels(target, v):
                # Clone so the live parameter is only updated by load_state_dict.
                merged = target.clone()
                merged[:, :v.shape[1], ...] = v
                model_dict[k] = merged
            else:
                # e.g. 1-D biases or classifier heads of a different size, which
                # previously raised an IndexError or were silently broadcast.
                print("[%s]"%(self.__class__.__name__), k, "is ignored due to not matching shape")
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    @staticmethod
    def _fits_leading_channels(target, source):
        """Return True if ``source`` can fill the first channels (dim 1) of ``target``.

        This requires the same rank (at least 2), an identical size on every
        dim except dim 1, and no more channels in ``source`` than in ``target``.
        """
        if target.dim() < 2 or target.dim() != source.dim():
            return False
        if source.shape[1] > target.shape[1]:
            return False
        return target.shape[0] == source.shape[0] and target.shape[2:] == source.shape[2:]
#------------------------------------------------------------------------------
# BaseBackboneWrapper
#------------------------------------------------------------------------------
class BaseBackboneWrapper(BaseBackbone):
    """Backbone base class whose ``train()`` also re-applies stage freezing and,
    when ``self.norm_eval`` is set, keeps norm layers in eval mode while training.

    Subclasses are expected to override ``_freeze_stages``/``init_from_imagenet``
    and may set ``norm_eval``.
    """

    def __init__(self):
        super(BaseBackboneWrapper, self).__init__()

    def train(self, mode=True):
        """Switch train/eval mode, then re-freeze stages and optionally norm layers.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            ``self``, like ``nn.Module.train``. Because ``nn.Module.eval()``
            returns ``self.train(False)``, returning nothing here made
            ``model.eval()`` return ``None``.
        """
        if mode:
            print("[%s] Switch to train mode" % (self.__class__.__name__))
        else:
            print("[%s] Switch to eval mode" % (self.__class__.__name__))
        super(BaseBackboneWrapper, self).train(mode)
        self._freeze_stages()
        # norm_eval is expected to be set by subclasses; a missing attribute is
        # treated as "disabled" instead of raising AttributeError.
        if mode and getattr(self, 'norm_eval', False):
            # self.modules() already recurses into nn.Sequential containers, so
            # one pass reaches every norm layer.
            # trick: eval have effect on BatchNorm only
            for module in self.modules():
                if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                    module.eval()
        return self

    def init_from_imagenet(self, archname):
        """Hook for subclasses; no-op here."""
        pass

    def _freeze_stages(self):
        """Hook for subclasses to freeze parameters; called from ``train()``. No-op here."""
        pass
MobileNetV2.py
#------------------------------------------------------------------------------
# Libraries
#------------------------------------------------------------------------------
import math, torch, json
import torch.nn as nn
from functools import reduce
#------------------------------------------------------------------------------
# Useful functions
#------------------------------------------------------------------------------
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
def conv_bn(inp, oup, stride):
    """Build a 3x3 Conv2d (padding 1, no bias) -> BatchNorm2d -> in-place ReLU6.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the convolution.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
def conv_1x1_bn(inp, oup):
    """Build a pointwise (1x1, stride 1, no bias) Conv2d -> BatchNorm2d -> in-place ReLU6.

    Args:
        inp: number of input channels.
        oup: number of output channels.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(pointwise, norm, activation)
#------------------------------------------------------------------------------
# Class of Inverted Residual block
#------------------------------------------------------------------------------
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Layout:
    [1x1 expand conv -> BN -> ReLU6] (omitted when ``expansion == 1``)
    -> 3x3 depthwise conv -> BN -> ReLU6
    -> 1x1 linear projection -> BN.
    The input is added to the output when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise conv; must be 1 or 2.
        expansion: ratio of hidden channels to input channels.
        dilation: dilation of the depthwise conv. Its padding equals
            ``dilation``, so the output size depends only on ``stride``.
            With the former fixed padding of 1, any ``dilation > 1`` shrank
            the feature map and broke the residual addition.
    """

    def __init__(self, inp, oup, stride, expansion, dilation=1):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = round(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup
        layers = []
        if expansion != 1:
            layers += [
                # pw: expand to hidden_dim channels
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # dw: one 3x3 filter per channel; padding=dilation preserves size at stride 1
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, dilation, groups=hidden_dim, dilation=dilation, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear: project to oup channels without an activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        # Sequential indices (and so state_dict keys) match the original two-branch layout.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        return self.conv(x)
#------------------------------------------------------------------------------
# Class of MobileNetV2
#------------------------------------------------------------------------------
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier / feature extractor.

    ``self.features`` always holds 19 modules:
    - a stride-2 ``conv_bn`` stem,
    - 17 ``InvertedResidual`` blocks,
    - a final ``conv_1x1_bn`` to ``self.last_channel`` channels.
    Code outside this class (e.g. a UNet encoder) indexes these positions
    directly, so their order must not change.

    Args:
        alpha: width multiplier applied to the channel counts.
        expansion: expansion ratio of every inverted-residual group except the
            first, which always uses 1.
        num_classes: classifier output size. ``None`` builds no classifier, and
            ``forward`` then returns the final feature map.
    """

    def __init__(self, alpha=1.0, expansion=6, num_classes=1000):
        super(MobileNetV2, self).__init__()
        self.num_classes = num_classes
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t (expansion), c (output channels), n (repeats), s (stride of first repeat)
            [1, 16, 1, 1],
            [expansion, 24, 2, 2],
            [expansion, 32, 3, 2],
            [expansion, 64, 4, 2],
            [expansion, 96, 3, 1],
            [expansion, 160, 3, 2],
            [expansion, 320, 1, 1],
        ]
        # building first layer
        input_channel = _make_divisible(input_channel * alpha, 8)
        # The last conv is widened by alpha > 1 but never narrowed.
        self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
        features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks; only the first block of a group strides
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(int(c * alpha), 8)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(InvertedResidual(input_channel, output_channel, stride, expansion=t))
                input_channel = output_channel
        # building last several layers
        features.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*features)
        # building classifier
        if self.num_classes is not None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(self.last_channel, num_classes),
            )
        # Initialize weights
        self._init_weights()

    def forward(self, x, feature_names=None):
        """Run the network.

        Args:
            x: image batch; the stem conv expects 3 input channels.
            feature_names: unused; kept for interface compatibility.

        Returns:
            Class scores of shape (N, num_classes) if a classifier exists,
            otherwise the output of the last feature module.
        """
        # The 19 feature modules run in order. Encoders that tap intermediate
        # maps split the stages after indices 1, 3, 6, 13 and 18.
        for layer in self.features:
            x = layer(x)
        if self.num_classes is not None:
            # global average pooling over the two spatial dims
            x = x.mean(dim=(2, 3))
            x = self.classifier(x)
        return x

    def _load_pretrained_model(self, pretrained_file):
        """Load weights from ``pretrained_file`` wherever key and shape match.

        Unknown keys and shape-mismatched tensors are reported and skipped.
        Previously a mismatched shape made ``load_state_dict`` raise. A
        checkpoint that wraps its weights under ``'state_dict'`` is unwrapped.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        pretrain_dict = torch.load(pretrained_file, map_location='cpu')
        if 'state_dict' in pretrain_dict:
            pretrain_dict = pretrain_dict['state_dict']
        model_dict = {}
        state_dict = self.state_dict()
        print("[MobileNetV2] Loading pretrained model...")
        for k, v in pretrain_dict.items():
            if k not in state_dict:
                print(k, "is ignored")
            elif state_dict[k].shape != v.shape:
                print(k, "is ignored due to not matching shape")
            else:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    def _init_weights(self):
        """Initialize weights in place.

        Conv2d uses N(0, sqrt(2 / (kh * kw * out_channels))), i.e. He init
        with fan_out. BatchNorm2d is set to an identity affine transform.
        Linear weights are drawn from N(0, 0.01). Biases are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
UNet.py
#------------------------------------------------------------------------------
# Libraries
#------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from base import BaseModel
from models.backbonds import MobileNetV2, ResNet
#------------------------------------------------------------------------------
# Decoder block
#------------------------------------------------------------------------------
class DecoderBlock(nn.Module):
    """UNet decoder stage: 2x upsampling, skip concatenation, then refinement.

    Args:
        in_channels: channels of the tensor being upsampled.
        out_channels: channels produced by the transposed convolution.
        block_unit: module applied to the concatenation. Its input has
            ``out_channels`` plus the shortcut's channel count.
    """

    def __init__(self, in_channels, out_channels, block_unit):
        super(DecoderBlock, self).__init__()
        # kernel 4 / stride 2 / padding 1 doubles H and W exactly:
        # (H - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 2H.
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, stride=2, kernel_size=4, padding=1)
        self.block_unit = block_unit

    def forward(self, input, shortcut):
        """Upsample ``input``, append ``shortcut`` along channels and refine.

        ``shortcut`` must be exactly twice the spatial size of ``input``;
        otherwise ``torch.cat`` raises.
        """
        upsampled = self.deconv(input)
        merged = torch.cat((upsampled, shortcut), dim=1)
        return self.block_unit(merged)
#------------------------------------------------------------------------------
# Class of UNet
#------------------------------------------------------------------------------
class UNet(BaseModel):
    """UNet-style segmentation network on a MobileNetV2 or ResNet encoder.

    The encoder yields five feature maps x1..x5 (see ``_run_backbone_*``).
    Four ``DecoderBlock`` stages upsample from x5 and fuse x4, x3, x2 and x1
    in turn. ``conv_last`` maps to ``num_classes`` channels, and the result
    is bilinearly resized to the input's spatial size.

    Args:
        backbone: ``'mobilenetv2'`` or one of ``'resnet18'``, ``'resnet34'``,
            ``'resnet50'``, ``'resnet101'``. Any other value raises
            ``NotImplementedError``.
        num_classes: number of output channels of the final convolution.
        pretrained_backbone: optional checkpoint passed to
            ``self.backbone._load_pretrained_model`` after weight init, so the
            pretrained weights override the random init of the encoder.
    """
    def __init__(self, backbone="mobilenetv2", num_classes=2, pretrained_backbone=None):
        super(UNet, self).__init__()
        if backbone=='mobilenetv2':
            alpha = 1.0
            expansion = 6
            # num_classes=None: MobileNetV2 without its classifier head.
            self.backbone = MobileNetV2.MobileNetV2(alpha=alpha, expansion=expansion, num_classes=None)
            # Bound method stored as an attribute; forward() dispatches through it.
            self._run_backbone = self._run_backbone_mobilenetv2
            # Decoder widths (96/32/24/16 at alpha=1.0) equal the channel counts of
            # the skip maps x4/x3/x2/x1. Each concatenation therefore has 2*channelN
            # channels, matching each block_unit's input width.
            # Stage 1
            channel1 = MobileNetV2._make_divisible(int(96*alpha), 8)
            block_unit = MobileNetV2.InvertedResidual(2*channel1, channel1, 1, expansion)
            self.decoder1 = DecoderBlock(self.backbone.last_channel, channel1, block_unit)
            # Stage 2
            channel2 = MobileNetV2._make_divisible(int(32*alpha), 8)
            block_unit = MobileNetV2.InvertedResidual(2*channel2, channel2, 1, expansion)
            self.decoder2 = DecoderBlock(channel1, channel2, block_unit)
            # Stage 3
            channel3 = MobileNetV2._make_divisible(int(24*alpha), 8)
            block_unit = MobileNetV2.InvertedResidual(2*channel3, channel3, 1, expansion)
            self.decoder3 = DecoderBlock(channel2, channel3, block_unit)
            # Stage 4
            channel4 = MobileNetV2._make_divisible(int(16*alpha), 8)
            block_unit = MobileNetV2.InvertedResidual(2*channel4, channel4, 1, expansion)
            self.decoder4 = DecoderBlock(channel3, channel4, block_unit)
        elif 'resnet' in backbone:
            if backbone=='resnet18':
                n_layers = 18
            elif backbone=='resnet34':
                n_layers = 34
            elif backbone=='resnet50':
                n_layers = 50
            elif backbone=='resnet101':
                n_layers = 101
            else:
                raise NotImplementedError
            filters = 64
            self.backbone = ResNet.get_resnet(n_layers, num_classes=None)
            self._run_backbone = self._run_backbone_resnet
            # resnet18/34 use BasicBlock, deeper variants Bottleneck; the channel
            # counts below assume the usual 4x wider Bottleneck outputs
            # (NOTE(review): confirm against ResNet.py, not shown here).
            block = ResNet.BasicBlock if (n_layers==18 or n_layers==34) else ResNet.Bottleneck
            # Stage 1
            last_channel = 8*filters if (n_layers==18 or n_layers==34) else 32*filters
            channel1 = 4*filters if (n_layers==18 or n_layers==34) else 16*filters
            # 1x1 projection so the block's residual path matches channel1 outputs.
            downsample = nn.Sequential(ResNet.conv1x1(2*channel1, channel1), nn.BatchNorm2d(channel1))
            block_unit = block(2*channel1, int(channel1/block.expansion), 1, downsample)
            self.decoder1 = DecoderBlock(last_channel, channel1, block_unit)
            # Stage 2
            channel2 = 2*filters if (n_layers==18 or n_layers==34) else 8*filters
            downsample = nn.Sequential(ResNet.conv1x1(2*channel2, channel2), nn.BatchNorm2d(channel2))
            block_unit = block(2*channel2, int(channel2/block.expansion), 1, downsample)
            self.decoder2 = DecoderBlock(channel1, channel2, block_unit)
            # Stage 3
            channel3 = filters if (n_layers==18 or n_layers==34) else 4*filters
            downsample = nn.Sequential(ResNet.conv1x1(2*channel3, channel3), nn.BatchNorm2d(channel3))
            block_unit = block(2*channel3, int(channel3/block.expansion), 1, downsample)
            self.decoder3 = DecoderBlock(channel2, channel3, block_unit)
            # Stage 4
            channel4 = filters
            downsample = nn.Sequential(ResNet.conv1x1(2*channel4, channel4), nn.BatchNorm2d(channel4))
            block_unit = block(2*channel4, int(channel4/block.expansion), 1, downsample)
            self.decoder4 = DecoderBlock(channel3, channel4, block_unit)
        else:
            raise NotImplementedError
        # Two 3x3 convs with no nonlinearity between them: channel4 -> 3 -> num_classes.
        self.conv_last = nn.Sequential(
            nn.Conv2d(channel4, 3, kernel_size=3, padding=1),
            nn.Conv2d(3, num_classes, kernel_size=3, padding=1),
        )
        # Initialize
        self._init_weights()
        # Loaded after _init_weights so the pretrained encoder weights are kept.
        if pretrained_backbone is not None:
            self.backbone._load_pretrained_model(pretrained_backbone)
    def forward(self, input):
        """Return per-pixel class scores (no activation applied) whose spatial
        size equals ``input.shape[-2:]``.

        Each DecoderBlock doubles H and W exactly. With the MobileNetV2 encoder
        (five stride-2 layers), the skip maps line up only when input H and W
        are multiples of 32. Other sizes presumably fail in torch.cat inside a
        DecoderBlock; confirm for your export resolution.
        """
        x1, x2, x3, x4, x5 = self._run_backbone(input)
        x = self.decoder1(x5, x4)
        x = self.decoder2(x, x3)
        x = self.decoder3(x, x2)
        x = self.decoder4(x, x1)
        x = self.conv_last(x)
        # NOTE(review): torch.onnx.export reportedly cannot export bilinear
        # interpolation with align_corners=True below opset 11. This is a likely
        # cause of ONNX export failures; confirm against the exporter's error
        # message and try opset_version>=11.
        x = F.interpolate(x, size=input.shape[-2:], mode='bilinear', align_corners=True)
        return x
    def _run_backbone_mobilenetv2(self, input):
        """Run MobileNetV2's 19 feature modules and tap five stages.

        Channel counts assume alpha=1.0:
        - x1 = output of features[1]: 16 ch, 1/2 resolution
        - x2 = output of features[3]: 24 ch, 1/4 resolution
        - x3 = output of features[6]: 32 ch, 1/8 resolution
        - x4 = output of features[13]: 96 ch, 1/16 resolution
        - x5 = output of features[18]: last_channel ch, 1/32 resolution
        """
        x = input
        # Stage1
        x = reduce(lambda x, n: self.backbone.features[n](x), list(range(0,2)), x)
        x1 = x
        # Stage2
        x = reduce(lambda x, n: self.backbone.features[n](x), list(range(2,4)), x)
        x2 = x
        # Stage3
        x = reduce(lambda x, n: self.backbone.features[n](x), list(range(4,7)), x)
        x3 = x
        # Stage4
        x = reduce(lambda x, n: self.backbone.features[n](x), list(range(7,14)), x)
        x4 = x
        # Stage5
        x5 = reduce(lambda x, n: self.backbone.features[n](x), list(range(14,19)), x)
        return x1, x2, x3, x4, x5
    def _run_backbone_resnet(self, input):
        """Run the ResNet encoder and tap five stages.

        x1 is the stem output (conv1/bn1/relu) taken before the max-pool.
        x2..x5 are the outputs of layer1..layer4.
        """
        # Stage1
        x1 = self.backbone.conv1(input)
        x1 = self.backbone.bn1(x1)
        x1 = self.backbone.relu(x1)
        # Stage2
        x2 = self.backbone.maxpool(x1)
        x2 = self.backbone.layer1(x2)
        # Stage3
        x3 = self.backbone.layer2(x2)
        # Stage4
        x4 = self.backbone.layer3(x3)
        # Stage5
        x5 = self.backbone.layer4(x4)
        return x1, x2, x3, x4, x5
    def _init_weights(self):
        """Initialize weights in place, including those inside ``self.backbone``.

        Conv2d uses Kaiming normal (fan_out, relu) with zero bias. BatchNorm2d
        gets weight 1 and bias 0. The decoder ConvTranspose2d layers are
        presumably not matched by ``isinstance(m, nn.Conv2d)`` and keep
        PyTorch's default init (NOTE(review): confirm if that is intended).
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)