AI Model Efficiency Toolkit (AIMET) Forum

Exception while exporting Quantized model

I could successfully run quantization on my model (which is a very complex one) using the AIMET PyTorch APIs. However, the export failed: it exported the model (.pth file), but it could not export the JSON file or the ONNX-format file. Here is the error. Can someone help?

# Export the quantization-simulation model to ./quantization using the file prefix
# 'unet-fine-tuned'; per the traceback below, this step calls torch.onnx.export internally.
simulation.export('./quantization', 'unet-fine-tuned', input_shape=(1, 3, 800, 800), set_onnx_layer_names=True)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-c399a863d0ce> in <module>()
----> 1 simulation.export('./quantization', 'unet-fine-tuned', input_shape=(1, 3, 800, 800), set_onnx_layer_names=True)

10 frames
/usr/local/lib/python3.6/dist-packages/aimet_torch/quantsim.py in export(self, path, filename_prefix, input_shape, set_onnx_layer_names)
    268         onnx_path = os.path.join(path, filename_prefix + '.onnx')
    269         dummy_input = utils.create_rand_tensors_given_shapes(input_shape)
--> 270         torch.onnx.export(model_to_export, tuple(dummy_input), onnx_path)
    271 
    272         #  Set the onnx layer names

/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py in export(*args, **kwargs)
     23 def export(*args, **kwargs):
     24     from torch.onnx import utils
---> 25     return utils.export(*args, **kwargs)
     26 
     27 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
    129             operator_export_type=operator_export_type, opset_version=opset_version,
    130             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
--> 131             strip_doc_string=strip_doc_string)
    132 
    133 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
    361                                                         output_names, operator_export_type,
    362                                                         example_outputs, propagate,
--> 363                                                         _retain_param_name, do_constant_folding)
    364 
    365         # TODO: Don't allocate a in-memory string for the protobuf

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop)
    276 
    277     graph = _optimize_graph(graph, operator_export_type,
--> 278                             _disable_torch_constant_prop=_disable_torch_constant_prop)
    279 
    280     # NB: ONNX requires complete information about output types, which might be

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop)
    186 
    187     if operator_export_type != OperatorExportTypes.RAW:
--> 188         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    189         torch._C._jit_pass_lint(graph)
    190         torch._C._jit_pass_onnx_peephole(graph)

/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
     48 def _run_symbolic_function(*args, **kwargs):
     49     from torch.onnx import utils
---> 50     return utils._run_symbolic_function(*args, **kwargs)
     51 
     52 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _run_symbolic_function(g, n, inputs, env, operator_export_type)
    587                     return None
    588                 fn = getattr(torch.onnx.symbolic, op_name)
--> 589                 return fn(g, *inputs, **attrs)
    590 
    591         elif ns == "prim":

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic.py in wrapper(g, *args)
    128             # some args may be optional, so the length may be smaller
    129             assert len(arg_descriptors) >= len(args)
--> 130             args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
    131             return fn(g, *args)
    132         # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic.py in <listcomp>(.0)
    128             # some args may be optional, so the length may be smaller
    129             assert len(arg_descriptors) >= len(args)
--> 130             args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
    131             return fn(g, *args)
    132         # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic.py in _parse_arg(value, desc)
     88             for v in value.node().inputs():
     89                 if v.node().kind() != 'onnx::Constant':
---> 90                     raise RuntimeError("Failed to export an ONNX attribute, "
     91                                        "since it's not constant, please try to make "
     92                                        "things (e.g., kernel size) static if possible")

RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible

Looks like this is an issue with onnx export itself. Can anyone suggest help?

import torch
import torchvision
# Random input batch placed on the GPU, passed to the exporter below.
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
# Call torch.onnx.export directly on the wrapped model (without going through simulation.export).
torch.onnx.export(simulation.model, dummy_input, "unet_quantized.onnx", verbose=True)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-3e5f038c9ec7> in <module>()
      4 
      5 dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
----> 6 torch.onnx.export(simulation.model, dummy_input, "unet_quantized.onnx", verbose=True)

9 frames
/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py in export(*args, **kwargs)
     23 def export(*args, **kwargs):
     24     from torch.onnx import utils
---> 25     return utils.export(*args, **kwargs)
     26 
     27 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
    129             operator_export_type=operator_export_type, opset_version=opset_version,
    130             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
--> 131             strip_doc_string=strip_doc_string)
    132 
    133 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
    361                                                         output_names, operator_export_type,
    362                                                         example_outputs, propagate,
--> 363                                                         _retain_param_name, do_constant_folding)
    364 
    365         # TODO: Don't allocate a in-memory string for the protobuf

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop)
    276 
    277     graph = _optimize_graph(graph, operator_export_type,
--> 278                             _disable_torch_constant_prop=_disable_torch_constant_prop)
    279 
    280     # NB: ONNX requires complete information about output types, which might be

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop)
    186 
    187     if operator_export_type != OperatorExportTypes.RAW:
--> 188         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    189         torch._C._jit_pass_lint(graph)
    190         torch._C._jit_pass_onnx_peephole(graph)

/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
     48 def _run_symbolic_function(*args, **kwargs):
     49     from torch.onnx import utils
---> 50     return utils._run_symbolic_function(*args, **kwargs)
     51 
     52 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _run_symbolic_function(g, n, inputs, env, operator_export_type)
    587                     return None
    588                 fn = getattr(torch.onnx.symbolic, op_name)
--> 589                 return fn(g, *inputs, **attrs)
    590 
    591         elif ns == "prim":

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic.py in wrapper(g, *args)
    128             # some args may be optional, so the length may be smaller
    129             assert len(arg_descriptors) >= len(args)
--> 130             args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
    131             return fn(g, *args)
    132         # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic.py in <listcomp>(.0)
    128             # some args may be optional, so the length may be smaller
    129             assert len(arg_descriptors) >= len(args)
--> 130             args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
    131             return fn(g, *args)
    132         # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic.py in _parse_arg(value, desc)
     88             for v in value.node().inputs():
     89                 if v.node().kind() != 'onnx::Constant':
---> 90                     raise RuntimeError("Failed to export an ONNX attribute, "
     91                                        "since it's not constant, please try to make "
     92                                        "things (e.g., kernel size) static if possible")

RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible

It looks like this issue is present in PyTorch v1.1.0 (https://github.com/pytorch/pytorch/issues/19374) and is fixed in v1.6.0 or later (see https://github.com/pytorch/pytorch/pull/20116 and https://github.com/pytorch/pytorch/commit/93d5503f34ab54c915b640d3bc45d41e6afe0dd2).

AIMET works with PyTorch v1.1.0. Can you suggest a workaround for this?

Could you please post the relevant snippet from the model that is preventing export to ONNX?
Basically, to debug this we should first work out why the model (without any AIMET involvement) cannot be exported to ONNX. In the past, we have tweaked the model definition (without needing an architecture change) to get past these kinds of issues.

I am not yet sure which part of the model is responsible for the ONNX export failure. I am posting the model files here; hopefully that helps you. Ignore the ResNet part of the code: we used MobileNetV2 as the backbone.

Let me know if you need more information. Thanks for responding.

base_model.py

#------------------------------------------------------------------------------
#   Libraries
#------------------------------------------------------------------------------
import torch
import torch.nn as nn

import torchsummary
import os, warnings, sys
from utils import add_flops_counting_methods, flops_to_string


#------------------------------------------------------------------------------
#   BaseModel
#------------------------------------------------------------------------------
class BaseModel(nn.Module):
	"""
	Common base class for the models in this project.

	On top of ``nn.Module`` it provides a printable network summary, a default
	weight initialisation and a tolerant loader for pretrained weights.
	"""
	def __init__(self):
		super(BaseModel, self).__init__()

	def summary(self, input_shape, batch_size=1, device='cpu', print_flops=False):
		"""
		Print a torchsummary report and, when ``print_flops`` is set, a FLOPs figure.

		``input_shape`` is the per-sample shape without the batch dimension; a
		batch dimension of 1 is prepended for the FLOPs measurement.
		"""
		print("[%s] Network summary..." % (self.__class__.__name__))
		torchsummary.summary(self, input_size=input_shape, batch_size=batch_size, device=device)
		if print_flops:
			# Named dummy_input so the builtin ``input`` is not shadowed.
			dummy_input = torch.randn([1, *input_shape], dtype=torch.float)
			counter = add_flops_counting_methods(self)
			counter.eval().start_flops_count()
			counter(dummy_input)
			print('Flops:  {}'.format(flops_to_string(counter.compute_average_flops_cost())))
			print('----------------------------------------------------------------')

	def init_weights(self):
		"""
		Initialise every sub-module in place.

		Conv / transposed-conv weights use Kaiming-normal (fan_out, relu),
		normalisation layers get weight 1 and bias 0, linear layers get
		N(0, 0.01) weights. Missing (``None``) weights or biases are skipped.
		"""
		print("[%s] Initialize weights..." % (self.__class__.__name__))
		for m in self.modules():
			if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					m.bias.data.zero_()
			elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
				# affine=False layers have neither weight nor bias.
				if m.weight is not None:
					nn.init.constant_(m.weight, 1)
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)
			elif isinstance(m, nn.Linear):
				m.weight.data.normal_(0, 0.01)
				# Linear(bias=False) has no bias to reset.
				if m.bias is not None:
					m.bias.data.zero_()

	def load_pretrained_model(self, pretrained):
		"""
		Load the matching subset of a pretrained state dict into this model.

		``pretrained`` is either a checkpoint path (loaded onto the CPU; a
		top-level ``'state_dict'`` entry is unwrapped) or an already loaded
		dict. Entries whose key is unknown or whose shape differs are reported
		and ignored.

		Raises:
			TypeError: if ``pretrained`` is neither a ``str`` nor a ``dict``.
		"""
		if isinstance(pretrained, str):
			print("[%s] Load pretrained model from %s" % (self.__class__.__name__, pretrained))
			pretrain_dict = torch.load(pretrained, map_location='cpu')
			if 'state_dict' in pretrain_dict:
				pretrain_dict = pretrain_dict['state_dict']
		elif isinstance(pretrained, dict):
			print("[%s] Load pretrained model" % (self.__class__.__name__))
			pretrain_dict = pretrained
		else:
			# Previously this fell through and failed later on an unbound local.
			raise TypeError(
				"pretrained must be a checkpoint path (str) or a state dict (dict), got %s"
				% type(pretrained).__name__
			)

		model_dict = {}
		state_dict = self.state_dict()
		for k, v in pretrain_dict.items():
			if k in state_dict:
				if state_dict[k].shape==v.shape:
					model_dict[k] = v
				else:
					print("[%s]"%(self.__class__.__name__), k, "is ignored due to not matching shape")
			else:
				print("[%s]"%(self.__class__.__name__), k, "is ignored due to not matching key")
		state_dict.update(model_dict)
		self.load_state_dict(state_dict)


#------------------------------------------------------------------------------
#   BaseBackbone
#------------------------------------------------------------------------------
class BaseBackbone(BaseModel):
	"""
	Base class for backbones; adds a loader that tolerates a pretrained
	checkpoint whose tensors differ in shape from the model's.
	"""
	def __init__(self):
		super(BaseBackbone, self).__init__()

	def load_pretrained_model_extended(self, pretrained):
		"""
		Load pretrained weights into a model built with a different in_channels.

		``pretrained`` is either a checkpoint path (loaded onto the CPU; a
		top-level ``'state_dict'`` entry is unwrapped) or an already loaded
		dict. For a key whose shape differs from the model's tensor, the
		pretrained values are written into the first 3 entries of dimension 1
		and the rest of the model tensor keeps its current values. Unknown
		keys are reported and ignored.

		Raises:
			TypeError: if ``pretrained`` is neither a ``str`` nor a ``dict``.
		"""
		if isinstance(pretrained, str):
			print("[%s] Load pretrained model from %s" % (self.__class__.__name__, pretrained))
			pretrain_dict = torch.load(pretrained, map_location='cpu')
			if 'state_dict' in pretrain_dict:
				pretrain_dict = pretrain_dict['state_dict']
		elif isinstance(pretrained, dict):
			print("[%s] Load pretrained model" % (self.__class__.__name__))
			pretrain_dict = pretrained
		else:
			# Previously this fell through and failed later on an unbound local.
			raise TypeError(
				"pretrained must be a checkpoint path (str) or a state dict (dict), got %s"
				% type(pretrained).__name__
			)

		model_dict = {}
		state_dict = self.state_dict()
		for k, v in pretrain_dict.items():
			if k in state_dict:
				if state_dict[k].shape!=v.shape:
					# Keep the model's tensor and overwrite only its first 3 channels.
					model_dict[k] = state_dict[k]
					model_dict[k][:,:3,...] = v
				else:
					model_dict[k] = v
			else:
				print("[%s]"%(self.__class__.__name__), k, "is ignored")
		state_dict.update(model_dict)
		self.load_state_dict(state_dict)


#------------------------------------------------------------------------------
#  BaseBackboneWrapper
#------------------------------------------------------------------------------
class BaseBackboneWrapper(BaseBackbone):
	"""
	Backbone wrapper whose ``train`` override re-applies stage freezing and can
	keep normalisation layers in eval mode while the rest of the model trains.

	``self.norm_eval`` is read by ``train`` but not assigned in this class;
	presumably subclasses set it in their constructor (TODO confirm).
	"""
	def __init__(self):
		super(BaseBackboneWrapper, self).__init__()

	def train(self, mode=True):
		"""
		Switch between train and eval mode and return ``self``.

		After the regular mode switch, ``_freeze_stages`` is called. When
		entering train mode with ``norm_eval`` set, BatchNorm2d modules (and
		BatchNorm2d / GroupNorm children of ``nn.Sequential`` containers) are
		put back into eval mode.
		"""
		if mode:
			print("[%s] Switch to train mode" % (self.__class__.__name__))
		else:
			print("[%s] Switch to eval mode" % (self.__class__.__name__))

		super(BaseBackboneWrapper, self).train(mode)
		self._freeze_stages()
		if mode and self.norm_eval:
			for module in self.modules():
				# trick: eval have effect on BatchNorm only
				if isinstance(module, nn.BatchNorm2d):
					module.eval()
				elif isinstance(module, nn.Sequential):
					for m in module:
						# Check the child ``m``: the container ``module`` is a
						# Sequential here and can never be a norm layer.
						if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
							m.eval()
		# nn.Module.train returns self, and nn.Module.eval relies on that.
		return self

	def init_from_imagenet(self, archname):
		"""Hook for loading ImageNet weights; a no-op in this base class."""
		pass

	def _freeze_stages(self):
		"""Hook for freezing backbone stages; a no-op in this base class."""
		pass

MobileNetV2.py

#------------------------------------------------------------------------------
#  Libraries
#------------------------------------------------------------------------------
import math, torch, json
import torch.nn as nn
from functools import reduce


#------------------------------------------------------------------------------
#  Useful functions
#------------------------------------------------------------------------------
def _make_divisible(v, divisor, min_value=None):
	"""
	Round ``v`` to the nearest multiple of ``divisor``.

	The result is never below ``min_value`` (``divisor`` when omitted), and it
	is bumped up by one ``divisor`` if rounding lost more than 10% of ``v``.
	"""
	lower_bound = divisor if min_value is None else min_value
	rounded = int(v + divisor / 2) // divisor * divisor
	result = rounded if rounded > lower_bound else lower_bound
	# Make sure that round down does not go down by more than 10%.
	if result < 0.9 * v:
		result = result + divisor
	return result


def conv_bn(inp, oup, stride):
	"""Build a bias-free 3x3 conv (padding 1) -> BatchNorm2d -> in-place ReLU6 stack."""
	layers = [
		nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
		nn.BatchNorm2d(oup),
		nn.ReLU6(inplace=True),
	]
	return nn.Sequential(*layers)


def conv_1x1_bn(inp, oup):
	"""Build a bias-free 1x1 conv (stride 1, no padding) -> BatchNorm2d -> in-place ReLU6 stack."""
	layers = [
		nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
		nn.BatchNorm2d(oup),
		nn.ReLU6(inplace=True),
	]
	return nn.Sequential(*layers)


#------------------------------------------------------------------------------
#  Class of Inverted Residual block
#------------------------------------------------------------------------------
class InvertedResidual(nn.Module):
	"""
	Inverted residual block: optional 1x1 expansion, 3x3 depthwise conv and a
	linear 1x1 projection, each followed by BatchNorm2d.

	Args:
		inp: number of input channels.
		oup: number of output channels.
		stride: stride of the depthwise conv; must be 1 or 2.
		expansion: channel expansion factor; with 1 the expansion conv is omitted.
		dilation: dilation of the depthwise conv.

	The input is added to the output when ``stride == 1`` and ``inp == oup``.
	"""
	def __init__(self, inp, oup, stride, expansion, dilation=1):
		super(InvertedResidual, self).__init__()
		self.stride = stride
		assert stride in [1, 2]

		hidden_dim = round(inp * expansion)
		self.use_res_connect = self.stride == 1 and inp == oup

		layers = []
		if expansion != 1:
			# pw
			layers += [
				nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
				nn.BatchNorm2d(hidden_dim),
				nn.ReLU6(inplace=True),
			]
		layers += [
			# dw: padding follows the dilation so that a 3x3 kernel keeps the
			# spatial size at stride 1; a fixed padding of 1 shrank the output
			# for dilation > 1 and broke the residual addition.
			nn.Conv2d(hidden_dim, hidden_dim, 3, stride, dilation, groups=hidden_dim, dilation=dilation, bias=False),
			nn.BatchNorm2d(hidden_dim),
			nn.ReLU6(inplace=True),
			# pw-linear
			nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
			nn.BatchNorm2d(oup),
		]
		self.conv = nn.Sequential(*layers)

	def forward(self, x):
		"""Apply the block, adding the input back when the residual connection is enabled."""
		if self.use_res_connect:
			return x + self.conv(x)
		else:
			return self.conv(x)


#------------------------------------------------------------------------------
#  Class of MobileNetV2
#------------------------------------------------------------------------------
class MobileNetV2(nn.Module):
	"""
	MobileNetV2 feature extractor with an optional classification head.

	Args:
		alpha: width multiplier applied to the channel counts.
		expansion: expansion factor used by all but the first residual group.
		num_classes: size of the classifier output; ``None`` builds no
			classifier and makes ``forward`` return the last feature map.
	"""
	def __init__(self, alpha=1.0, expansion=6, num_classes=1000):
		super(MobileNetV2, self).__init__()
		self.num_classes = num_classes
		stem_channels = 32
		head_channels = 1280
		# Each row: expansion factor t, output channels c, repeats n, first stride s.
		block_settings = [
			[1        , 16, 1, 1],
			[expansion, 24, 2, 2],
			[expansion, 32, 3, 2],
			[expansion, 64, 4, 2],
			[expansion, 96, 3, 1],
			[expansion, 160, 3, 2],
			[expansion, 320, 1, 1],
		]

		# First layer
		in_ch = _make_divisible(stem_channels*alpha, 8)
		if alpha > 1.0:
			self.last_channel = _make_divisible(head_channels*alpha, 8)
		else:
			self.last_channel = head_channels
		layers = [conv_bn(3, in_ch, 2)]

		# Inverted residual blocks; only the first block of a group may stride.
		for t, c, n, s in block_settings:
			out_ch = _make_divisible(int(c*alpha), 8)
			for repeat in range(n):
				block_stride = s if repeat == 0 else 1
				layers.append(InvertedResidual(in_ch, out_ch, block_stride, expansion=t))
				in_ch = out_ch

		# Last 1x1 conv
		layers.append(conv_1x1_bn(in_ch, self.last_channel))
		self.features = nn.Sequential(*layers)

		# Classifier head
		if self.num_classes is not None:
			self.classifier = nn.Sequential(
				nn.Dropout(0.2),
				nn.Linear(self.last_channel, num_classes),
			)

		# Initialize weights
		self._init_weights()


	def forward(self, x, feature_names=None):
		"""
		Run ``self.features[0..18]`` in order and, when a classifier exists,
		average over dimensions 2 and 3 and classify. ``feature_names`` is
		accepted but not used.
		"""
		# (start, stop) index ranges of Stage1 .. Stage5 inside self.features
		stage_bounds = ((0, 2), (2, 4), (4, 7), (7, 14), (14, 19))
		for start, stop in stage_bounds:
			for idx in range(start, stop):
				x = self.features[idx](x)

		# Classification
		if self.num_classes is not None:
			x = x.mean(dim=(2,3))
			x = self.classifier(x)

		# Output
		return x


	def _load_pretrained_model(self, pretrained_file):
		"""
		Load a checkpoint from ``pretrained_file`` onto the CPU and copy every
		entry whose key exists in this model; other keys are reported and skipped.
		"""
		checkpoint = torch.load(pretrained_file, map_location='cpu')
		own_state = self.state_dict()
		print("[MobileNetV2] Loading pretrained model...")
		matched = {}
		for key, tensor in checkpoint.items():
			if key not in own_state:
				print(key, "is ignored")
				continue
			matched[key] = tensor
		own_state.update(matched)
		self.load_state_dict(own_state)


	def _init_weights(self):
		"""
		Initialise conv weights with N(0, sqrt(2 / (k_h * k_w * out_channels))),
		BatchNorm2d with weight 1 / bias 0, and linear layers with N(0, 0.01)
		weights and zero bias.
		"""
		for module in self.modules():
			if isinstance(module, nn.Conv2d):
				fan = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
				module.weight.data.normal_(0, math.sqrt(2. / fan))
				if module.bias is not None:
					module.bias.data.zero_()
			elif isinstance(module, nn.BatchNorm2d):
				module.weight.data.fill_(1)
				module.bias.data.zero_()
			elif isinstance(module, nn.Linear):
				module.weight.data.normal_(0, 0.01)
				module.bias.data.zero_()

UNet.py

#------------------------------------------------------------------------------
#   Libraries
#------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

from base import BaseModel
from models.backbonds import MobileNetV2, ResNet


#------------------------------------------------------------------------------
#   Decoder block
#------------------------------------------------------------------------------
class DecoderBlock(nn.Module):
	"""
	Decoder stage: a stride-2 transposed conv, concatenation with a skip
	tensor along dimension 1, then the supplied ``block_unit`` module.
	"""
	def __init__(self, in_channels, out_channels, block_unit):
		super(DecoderBlock, self).__init__()
		self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
		self.block_unit = block_unit

	def forward(self, input, shortcut):
		"""Upsample ``input``, concatenate ``shortcut`` on dim 1 and apply ``block_unit``."""
		upsampled = self.deconv(input)
		merged = torch.cat((upsampled, shortcut), dim=1)
		return self.block_unit(merged)


#------------------------------------------------------------------------------
#   Class of UNet
#------------------------------------------------------------------------------
class UNet(BaseModel):
	"""
	U-Net style network: a MobileNetV2 or ResNet backbone produces five feature
	maps, four ``DecoderBlock`` stages merge them from deepest to shallowest,
	and two 3x3 convs plus a bilinear resize produce ``num_classes`` channels
	at the input's spatial size.

	Args:
		backbone: ``'mobilenetv2'`` or one of ``'resnet18'``, ``'resnet34'``,
			``'resnet50'``, ``'resnet101'``.
		num_classes: number of output channels.
		pretrained_backbone: optional checkpoint path handed to the backbone's
			``_load_pretrained_model``.

	Raises:
		NotImplementedError: for any other ``backbone`` value.
	"""
	def __init__(self, backbone="mobilenetv2", num_classes=2, pretrained_backbone=None):
		super(UNet, self).__init__()
		if backbone=='mobilenetv2':
			alpha = 1.0
			expansion = 6
			self.backbone = MobileNetV2.MobileNetV2(alpha=alpha, expansion=expansion, num_classes=None)
			self._run_backbone = self._run_backbone_mobilenetv2
			# Stage 1
			channel1 = MobileNetV2._make_divisible(int(96*alpha), 8)
			block_unit = MobileNetV2.InvertedResidual(2*channel1, channel1, 1, expansion)
			self.decoder1 = DecoderBlock(self.backbone.last_channel, channel1, block_unit)
			# Stage 2
			channel2 = MobileNetV2._make_divisible(int(32*alpha), 8)
			block_unit = MobileNetV2.InvertedResidual(2*channel2, channel2, 1, expansion)
			self.decoder2 = DecoderBlock(channel1, channel2, block_unit)
			# Stage 3
			channel3 = MobileNetV2._make_divisible(int(24*alpha), 8)
			block_unit = MobileNetV2.InvertedResidual(2*channel3, channel3, 1, expansion)
			self.decoder3 = DecoderBlock(channel2, channel3, block_unit)
			# Stage 4
			channel4 = MobileNetV2._make_divisible(int(16*alpha), 8)
			block_unit = MobileNetV2.InvertedResidual(2*channel4, channel4, 1, expansion)
			self.decoder4 = DecoderBlock(channel3, channel4, block_unit)

		elif 'resnet' in backbone:
			if backbone=='resnet18':
				n_layers = 18
			elif backbone=='resnet34':
				n_layers = 34
			elif backbone=='resnet50':
				n_layers = 50
			elif backbone=='resnet101':
				n_layers = 101
			else:
				raise NotImplementedError
			filters = 64
			self.backbone = ResNet.get_resnet(n_layers, num_classes=None)
			self._run_backbone = self._run_backbone_resnet
			# The 18/34-layer variants use BasicBlock and the narrower channel plan.
			is_basic = (n_layers==18 or n_layers==34)
			block = ResNet.BasicBlock if is_basic else ResNet.Bottleneck
			# Stage 1
			last_channel = 8*filters if is_basic else 32*filters
			channel1 = 4*filters if is_basic else 16*filters
			downsample = nn.Sequential(ResNet.conv1x1(2*channel1, channel1), nn.BatchNorm2d(channel1))
			block_unit = block(2*channel1, int(channel1/block.expansion), 1, downsample)
			self.decoder1 = DecoderBlock(last_channel, channel1, block_unit)
			# Stage 2
			channel2 = 2*filters if is_basic else 8*filters
			downsample = nn.Sequential(ResNet.conv1x1(2*channel2, channel2), nn.BatchNorm2d(channel2))
			block_unit = block(2*channel2, int(channel2/block.expansion), 1, downsample)
			self.decoder2 = DecoderBlock(channel1, channel2, block_unit)
			# Stage 3
			channel3 = filters if is_basic else 4*filters
			downsample = nn.Sequential(ResNet.conv1x1(2*channel3, channel3), nn.BatchNorm2d(channel3))
			block_unit = block(2*channel3, int(channel3/block.expansion), 1, downsample)
			self.decoder3 = DecoderBlock(channel2, channel3, block_unit)
			# Stage 4
			channel4 = filters
			downsample = nn.Sequential(ResNet.conv1x1(2*channel4, channel4), nn.BatchNorm2d(channel4))
			block_unit = block(2*channel4, int(channel4/block.expansion), 1, downsample)
			self.decoder4 = DecoderBlock(channel3, channel4, block_unit)

		else:
			raise NotImplementedError

		self.conv_last = nn.Sequential(
			nn.Conv2d(channel4, 3, kernel_size=3, padding=1),
			nn.Conv2d(3, num_classes, kernel_size=3, padding=1),
		)

		# Initialize
		self._init_weights()
		if pretrained_backbone is not None:
			self.backbone._load_pretrained_model(pretrained_backbone)


	def forward(self, input):
		"""Run backbone and decoders, then resize the result to ``input``'s last two dimensions."""
		x1, x2, x3, x4, x5 = self._run_backbone(input)
		x = self.decoder1(x5, x4)
		x = self.decoder2(x, x3)
		x = self.decoder3(x, x2)
		x = self.decoder4(x, x1)
		x = self.conv_last(x)
		# Pass the target size as plain Python ints. In eager mode this is the
		# same value as input.shape[-2:]; while tracing for torch.onnx.export
		# it keeps the size a constant instead of a traced shape value.
		# NOTE(review): this is the likely source of the "Failed to export an
		# ONNX attribute, since it's not constant" error; the exporter may
		# still reject align_corners=True on some PyTorch versions - confirm
		# on the target version.
		out_size = tuple(int(dim) for dim in input.shape[-2:])
		x = F.interpolate(x, size=out_size, mode='bilinear', align_corners=True)
		return x


	def _run_backbone_mobilenetv2(self, input):
		"""Run ``backbone.features[0..18]`` and return the output of each of the five stages."""
		# (start, stop) index ranges of Stage1 .. Stage5 inside backbone.features
		stage_bounds = ((0, 2), (2, 4), (4, 7), (7, 14), (14, 19))
		features = self.backbone.features
		stage_outputs = []
		x = input
		for start, stop in stage_bounds:
			for idx in range(start, stop):
				x = features[idx](x)
			stage_outputs.append(x)
		x1, x2, x3, x4, x5 = stage_outputs
		return x1, x2, x3, x4, x5


	def _run_backbone_resnet(self, input):
		"""Run the ResNet stem and layer1..layer4, returning the five intermediate outputs."""
		# Stage1
		x1 = self.backbone.conv1(input)
		x1 = self.backbone.bn1(x1)
		x1 = self.backbone.relu(x1)
		# Stage2
		x2 = self.backbone.maxpool(x1)
		x2 = self.backbone.layer1(x2)
		# Stage3
		x3 = self.backbone.layer2(x2)
		# Stage4
		x4 = self.backbone.layer3(x3)
		# Stage5
		x5 = self.backbone.layer4(x4)
		return x1, x2, x3, x4, x5


	def _init_weights(self):
		"""Kaiming-normal init for Conv2d weights (zero bias); weight 1 / bias 0 for BatchNorm2d."""
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					m.bias.data.zero_()
			elif isinstance(m, nn.BatchNorm2d):
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 0)