(即插即用模块-特征处理部分) 四十二、(2024 TPAMI) FreqFusion 频率特征融合
文章目录
- 1、Frequency-Aware Feature Fusion
- 2、代码实现
paper:Frequency-aware Feature Fusion for Dense Image Prediction
Code:https://github.com/Linwei-Chen/FreqFusion
1、Frequency-Aware Feature Fusion
FreqFusion 的动机主要针对现有特征融合方法在密集图像预测任务中存在的两个主要问题:(1)类内不一致性 (Intra-category inconsistency):由于物体内部不同区域的特征差异较大,简单的上采样和插值会使特征在物体内部变化剧烈,导致类内相似度降低,影响分类准确率。(2)边界位移 (Boundary displacement):低分辨率特征在降采样过程中会丢失高频细节信息,导致边界模糊,影响目标定位精度。为此,这篇论文提出了频率感知特征融合(Frequency-Aware Feature Fusion,FreqFusion),利用频率域分析来优化特征融合过程。
FreqFusion 的实现过程:
- 初始融合: 将低分辨率和高分辨率特征压缩并融合,生成压缩特征 Z,作为后续三个生成器的输入。
- 自适应低通滤波器 (ALPF) 生成器: 对 Z 应用 ALPF,得到平滑的高分辨率特征 Y1。
- 偏移量生成器: 对 Z 应用偏移量生成器,得到偏移量 u, v,并将高分辨率特征 Y1 中的特征点沿着偏移量方向重采样,得到重采样后的特征 Y2。
- 自适应高通滤波器 (AHPF) 生成器: 对 Z 应用 AHPF,得到增强边界信息的高分辨率特征 Y3,并将其与特征 Y2 进行逐元素相加,得到最终的融合特征。
Frequency-Aware Feature Fusion 结构图:
2、代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import warnings
import numpy as np

try:
    from mmcv.ops.carafe import normal_init, xavier_init, carafe
except ImportError:
    # Pure-PyTorch stand-ins used when mmcv is not installed.

    def xavier_init(module: nn.Module,
                    gain: float = 1,
                    bias: float = 0,
                    distribution: str = 'normal') -> None:
        """Xavier-initialise ``module.weight`` and set ``module.bias`` to a constant.

        Args:
            module: Module whose ``weight`` / ``bias`` (when present and not
                ``None``) are initialised in place.
            gain: Gain forwarded to the Xavier initialiser.
            bias: Constant written into the bias.
            distribution: ``'uniform'`` or ``'normal'``.
        """
        assert distribution in ['uniform', 'normal']
        if hasattr(module, 'weight') and module.weight is not None:
            if distribution == 'uniform':
                nn.init.xavier_uniform_(module.weight, gain=gain)
            else:
                nn.init.xavier_normal_(module.weight, gain=gain)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    def carafe(x, normed_mask, kernel_size, group=1, up=1):
        """Apply per-pixel ``kernel_size x kernel_size`` kernels to ``x``.

        Args:
            x: Input of shape ``(b, c, h, w)``.
            normed_mask: Kernels laid out on the channel axis; reshaped here to
                ``(b, 1, kernel_size**2, up*h, up*w)``, so it must hold exactly
                ``kernel_size**2`` channels at the output resolution.
            kernel_size: Side length of each kernel.
            group: Accepted for signature compatibility with the mmcv op; this
                fallback does not use it.
            up: Integer upsampling factor of the output relative to ``x``.

        Returns:
            Tensor of shape ``(b, c, up*h, up*w)``.
        """
        b, c, h, w = x.shape
        _, _, m_h, m_w = normed_mask.shape
        assert m_h == up * h
        assert m_w == up * w
        pad = kernel_size // 2
        # Reflect-pad so that every pixel owns a complete neighbourhood.
        pad_x = F.pad(x, pad=[pad] * 4, mode='reflect')
        unfold_x = F.unfold(pad_x, kernel_size=(kernel_size, kernel_size), stride=1, padding=0)
        unfold_x = unfold_x.reshape(b, c * kernel_size * kernel_size, h, w)
        # Nearest upsampling makes each output pixel reuse the neighbourhood
        # of the input pixel it was upsampled from.
        unfold_x = F.interpolate(unfold_x, scale_factor=up, mode='nearest')
        unfold_x = unfold_x.reshape(b, c, kernel_size * kernel_size, m_h, m_w)
        normed_mask = normed_mask.reshape(b, 1, kernel_size * kernel_size, m_h, m_w)
        res = unfold_x * normed_mask
        res = res.sum(dim=2).reshape(b, c, m_h, m_w)
        return res

    def normal_init(module, mean=0, std=1, bias=0):
        """Draw ``module.weight`` from N(mean, std) and set ``module.bias`` to ``bias``."""
        if hasattr(module, 'weight') and module.weight is not None:
            nn.init.normal_(module.weight, mean, std)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``."""
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Wrapper around ``F.interpolate`` with an optional alignment warning.

    When ``warning`` is set, ``size`` is given and ``align_corners`` is truthy,
    a warning is emitted if the output is larger than the input and neither
    ``(out - 1) % (in - 1)`` remainder is zero.

    Returns:
        The result of ``F.interpolate(input, size, scale_factor, mode, align_corners)``.
    """
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)


def hamming2D(M, N):
    """Return an ``M x N`` 2-D Hamming window (outer product of two 1-D windows)."""
    hamming_x = np.hamming(M)
    hamming_y = np.hamming(N)
    hamming_2d = np.outer(hamming_x, hamming_y)
    return hamming_2d


class FreqFusion(nn.Module):
    """Fuse a high-resolution and a low-resolution feature map.

    Two convolutions predict per-pixel kernels from channel-compressed
    features: ``content_encoder`` (low-pass, "ALPF generator") and
    ``content_encoder2`` (high-pass, "AHPF generator"). An optional
    ``LocalSimGuidedSampler`` resamples the upsampled low-resolution feature.

    Args:
        hr_channels: Channels of the high-resolution input.
        lr_channels: Channels of the low-resolution input.
        scale_factor: Multiplies the number of predicted kernel channels by
            ``scale_factor ** 2``; also the pixel-shuffle factor in
            ``kernel_normalizer`` when that method is given a scale factor.
        lowpass_kernel: Side length of the predicted low-pass kernels.
        highpass_kernel: Side length of the predicted high-pass kernels.
        up_group: Group count forwarded to ``carafe``.
        encoder_kernel: Kernel size of the two kernel-predicting convolutions.
        encoder_dilation: Dilation of the two kernel-predicting convolutions.
        compressed_channels: Channels after the 1x1 compressors.
        align_corners: Forwarded to ``resize`` when ``semi_conv`` is False and
            ``upsample_mode`` is not ``'nearest'``.
        upsample_mode: Interpolation mode used when ``semi_conv`` is False.
        feature_resample: Whether to use the offset generator.
        feature_resample_group: Groups of the offset generator.
        comp_feat_upsample: Use the predicted filters for the initial
            upsampling of the kernel logits (``semi_conv`` path only).
        use_high_pass: Whether to build and apply the high-pass branch.
        use_low_pass: Stored on the module; not read by the code in this file.
        hr_residual: Add the high-pass result back onto ``hr_feat``.
        semi_conv: Predict kernel logits from each input separately and merge
            them, instead of from the sum of the compressed inputs.
        hamming_window: Multiply kernels by a Hamming window before the final
            renormalisation.
        feature_resample_norm: Whether the offset generator normalises its
            inputs with GroupNorm.
    """

    def __init__(self,
                 hr_channels,
                 lr_channels,
                 scale_factor=1,
                 lowpass_kernel=5,
                 highpass_kernel=3,
                 up_group=1,
                 encoder_kernel=3,
                 encoder_dilation=1,
                 compressed_channels=64,
                 align_corners=False,
                 upsample_mode='nearest',
                 feature_resample=False,  # use offset generator or not
                 feature_resample_group=4,
                 comp_feat_upsample=True,  # use ALPF & AHPF for init upsampling
                 use_high_pass=True,
                 use_low_pass=True,
                 hr_residual=True,
                 semi_conv=True,
                 hamming_window=True,  # for regularization, do not matter really
                 feature_resample_norm=True,
                 **kwargs):
        super().__init__()
        self.scale_factor = scale_factor
        self.lowpass_kernel = lowpass_kernel
        self.highpass_kernel = highpass_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels

        encoder_padding = int((self.encoder_kernel - 1) * self.encoder_dilation / 2)
        kernels_per_pixel = self.up_group * self.scale_factor * self.scale_factor

        self.hr_channel_compressor = nn.Conv2d(hr_channels, self.compressed_channels, 1)
        self.lr_channel_compressor = nn.Conv2d(lr_channels, self.compressed_channels, 1)
        # ALPF generator
        self.content_encoder = nn.Conv2d(
            self.compressed_channels,
            lowpass_kernel ** 2 * kernels_per_pixel,
            self.encoder_kernel,
            padding=encoder_padding,
            dilation=self.encoder_dilation,
            groups=1)

        self.align_corners = align_corners
        self.upsample_mode = upsample_mode
        self.hr_residual = hr_residual
        self.use_high_pass = use_high_pass
        self.use_low_pass = use_low_pass
        self.semi_conv = semi_conv
        self.feature_resample = feature_resample
        self.comp_feat_upsample = comp_feat_upsample

        if self.feature_resample:
            self.dysampler = LocalSimGuidedSampler(
                in_channels=compressed_channels, scale=2, style='lp',
                groups=feature_resample_group, use_direct_scale=True,
                kernel_size=encoder_kernel, norm=feature_resample_norm)
        if self.use_high_pass:
            # AHPF generator
            self.content_encoder2 = nn.Conv2d(
                self.compressed_channels,
                highpass_kernel ** 2 * kernels_per_pixel,
                self.encoder_kernel,
                padding=encoder_padding,
                dilation=self.encoder_dilation,
                groups=1)

        self.hamming_window = hamming_window
        if self.hamming_window:
            self.register_buffer(
                'hamming_lowpass',
                torch.FloatTensor(hamming2D(lowpass_kernel, lowpass_kernel))[None, None])
            self.register_buffer(
                'hamming_highpass',
                torch.FloatTensor(hamming2D(highpass_kernel, highpass_kernel))[None, None])
        else:
            # A scalar 1 keeps the multiplication in kernel_normalizer a no-op.
            self.register_buffer('hamming_lowpass', torch.FloatTensor([1.0]))
            self.register_buffer('hamming_highpass', torch.FloatTensor([1.0]))
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise every Conv2d, then re-draw the kernel generators.

        ``self.modules()`` is recursive, so the convolutions of the optional
        offset sampler are visited by the Xavier loop as well.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')
        normal_init(self.content_encoder, std=0.001)
        if self.use_high_pass:
            normal_init(self.content_encoder2, std=0.001)

    def kernel_normalizer(self, mask, kernel, scale_factor=None, hamming=1):
        """Turn raw kernel logits into normalised per-pixel kernels.

        Args:
            mask: Logits of shape ``(n, groups * kernel**2, h, w)``.
            kernel: Kernel side length.
            scale_factor: When not ``None`` the logits are first pixel-shuffled
                by ``self.scale_factor``.
            hamming: Window multiplied onto every ``kernel x kernel`` kernel.

        Returns:
            Tensor with the same layout as ``mask`` in which every kernel is
            softmax-normalised, windowed and renormalised to sum to one.
        """
        if scale_factor is not None:
            mask = F.pixel_shuffle(mask, self.scale_factor)
        n, mask_c, h, w = mask.size()
        mask_channel = int(mask_c / float(kernel ** 2))  # group
        mask = mask.view(n, mask_channel, -1, h, w)
        mask = F.softmax(mask, dim=2, dtype=mask.dtype)
        mask = mask.view(n, mask_channel, kernel, kernel, h, w)
        mask = mask.permute(0, 1, 4, 5, 2, 3).view(n, -1, kernel, kernel)
        mask = mask * hamming
        mask /= mask.sum(dim=(-1, -2), keepdim=True)
        mask = mask.view(n, mask_channel, h, w, -1)
        mask = mask.permute(0, 1, 4, 2, 3).view(n, -1, h, w).contiguous()
        return mask

    def forward(self, hr_feat, lr_feat, use_checkpoint=False):
        """Run the fusion; ``use_checkpoint`` trades compute for GPU memory.

        Returns:
            Tuple ``(mask_lr, hr_feat, lr_feat)``: the normalised low-pass
            kernels, the processed high-resolution feature and the upsampled
            low-resolution feature.
        """
        if use_checkpoint:
            return checkpoint(self._forward, hr_feat, lr_feat)
        return self._forward(hr_feat, lr_feat)

    def _forward(self, hr_feat, lr_feat):
        compressed_hr_feat = self.hr_channel_compressor(hr_feat)
        compressed_lr_feat = self.lr_channel_compressor(lr_feat)
        hr_size = compressed_hr_feat.shape[-2:]

        if self.semi_conv:
            if self.comp_feat_upsample:
                if self.use_high_pass:
                    # Initial high-pass logits predicted from the HR feature.
                    mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat)
                    mask_hr_init = self.kernel_normalizer(
                        mask_hr_hr_feat, self.highpass_kernel, hamming=self.hamming_highpass)
                    # Boost the high frequencies of the compressed HR feature:
                    # x - lowpass(x) is the high-pass response of x.
                    compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(
                        compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1)

                    # Initial low-pass logits predicted from the HR feature.
                    mask_lr_hr_feat = self.content_encoder(compressed_hr_feat)
                    mask_lr_init = self.kernel_normalizer(
                        mask_lr_hr_feat, self.lowpass_kernel, hamming=self.hamming_lowpass)

                    # Low-pass logits predicted from the LR feature, upsampled
                    # with the initial low-pass kernels.
                    mask_lr_lr_feat_lr = self.content_encoder(compressed_lr_feat)
                    mask_lr_lr_feat = F.interpolate(
                        carafe(mask_lr_lr_feat_lr, mask_lr_init, self.lowpass_kernel, self.up_group, 2),
                        size=hr_size, mode='nearest')
                    # Merge both sets of low-pass logits.
                    mask_lr = mask_lr_hr_feat + mask_lr_lr_feat

                    # Kernels from the merged logits, used to upsample the
                    # high-pass logits predicted from the LR feature.
                    mask_lr_init = self.kernel_normalizer(
                        mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)
                    mask_hr_lr_feat = F.interpolate(
                        carafe(self.content_encoder2(compressed_lr_feat), mask_lr_init,
                               self.lowpass_kernel, self.up_group, 2),
                        size=hr_size, mode='nearest')
                    # Final high-pass logits.
                    mask_hr = mask_hr_hr_feat + mask_hr_lr_feat
                else:
                    raise NotImplementedError
            else:
                mask_lr = self.content_encoder(compressed_hr_feat) + F.interpolate(
                    self.content_encoder(compressed_lr_feat), size=hr_size, mode='nearest')
                if self.use_high_pass:
                    mask_hr = self.content_encoder2(compressed_hr_feat) + F.interpolate(
                        self.content_encoder2(compressed_lr_feat), size=hr_size, mode='nearest')
        else:
            compressed_x = F.interpolate(
                compressed_lr_feat, size=hr_size, mode='nearest') + compressed_hr_feat
            mask_lr = self.content_encoder(compressed_x)
            if self.use_high_pass:
                mask_hr = self.content_encoder2(compressed_x)

        mask_lr = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)
        if self.semi_conv:
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 2)
        else:
            lr_feat = resize(
                input=lr_feat,
                size=hr_feat.shape[2:],
                mode=self.upsample_mode,
                align_corners=None if self.upsample_mode == 'nearest' else self.align_corners)
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 1)

        if self.use_high_pass:
            mask_hr = self.kernel_normalizer(
                mask_hr, self.highpass_kernel, hamming=self.hamming_highpass)
            hr_feat_hf = hr_feat - carafe(hr_feat, mask_hr, self.highpass_kernel, self.up_group, 1)
            if self.hr_residual:
                hr_feat = hr_feat_hf + hr_feat
            else:
                hr_feat = hr_feat_hf

        if self.feature_resample:
            lr_feat = self.dysampler(
                hr_x=compressed_hr_feat, lr_x=compressed_lr_feat, feat2sample=lr_feat)

        return mask_lr, hr_feat, lr_feat


class LocalSimGuidedSampler(nn.Module):
    """Offset generator in FreqFusion.

    Predicts per-group sampling offsets from (optionally normalised) features
    and their local similarity maps, then resamples a feature map with
    ``F.grid_sample``.

    Args:
        in_channels: Channels of the guidance features ``hr_x`` / ``lr_x``.
        scale: Upsampling factor; only ``2`` is accepted.
        style: Only ``'lp'`` is accepted.
        groups: Number of offset groups; must divide ``in_channels``.
        use_direct_scale: Build convolutions that predict a sigmoid gate for
            the offsets; otherwise offsets are scaled by a constant 0.25.
        kernel_size: Kernel size of all offset / gate convolutions.
        local_window: Window size ``k`` of the similarity map.
        sim_type: Stored on the module; ``forward`` always requests cosine
            similarity.
        norm: Normalise the guidance features with GroupNorm.
        direction_feat: ``'sim'`` (similarity map only) or ``'sim_concat'``
            (features concatenated with the similarity map).
    """

    def __init__(self, in_channels, scale=2, style='lp', groups=4, use_direct_scale=True, kernel_size=1,
                 local_window=3, sim_type='cos', norm=True, direction_feat='sim_concat'):
        super().__init__()
        assert scale == 2
        assert style == 'lp'

        self.scale = scale
        self.style = style
        self.groups = groups
        self.local_window = local_window
        self.sim_type = sim_type
        self.direction_feat = direction_feat

        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        # Input widths of the offset and gate convolutions for each mode.
        sim_channels = local_window ** 2 - 1
        if self.direction_feat == 'sim':
            offset_in, gate_in = sim_channels, in_channels
        elif self.direction_feat == 'sim_concat':
            offset_in = gate_in = in_channels + sim_channels
        else:
            raise NotImplementedError
        conv_kwargs = dict(kernel_size=kernel_size, padding=kernel_size // 2)

        # Low-resolution branch: predicts scale**2 offset pairs per group.
        self.offset = nn.Conv2d(offset_in, out_channels, **conv_kwargs)
        normal_init(self.offset, std=0.001)
        if use_direct_scale:
            self.direct_scale = nn.Conv2d(gate_in, out_channels, **conv_kwargs)
            constant_init(self.direct_scale, val=0.)

        # High-resolution branch: one offset pair per group, later folded to
        # low resolution with pixel_unshuffle.
        out_channels = 2 * groups
        self.hr_offset = nn.Conv2d(offset_in, out_channels, **conv_kwargs)
        normal_init(self.hr_offset, std=0.001)
        if use_direct_scale:
            self.hr_direct_scale = nn.Conv2d(gate_in, out_channels, **conv_kwargs)
            constant_init(self.hr_direct_scale, val=0.)

        self.norm = norm
        if self.norm:
            self.norm_hr = nn.GroupNorm(in_channels // 8, in_channels)
            self.norm_lr = nn.GroupNorm(in_channels // 8, in_channels)
        else:
            self.norm_hr = nn.Identity()
            self.norm_lr = nn.Identity()
        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        """Return the initial sub-pixel offsets, shape ``(1, 2*groups*scale**2, 1, 1)``."""
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset, scale=None):
        """Resample ``x`` at the positions given by ``offset``.

        Args:
            x: Feature map whose channel count is divisible by ``self.groups``.
            offset: Offsets of shape ``(B, 2*groups*scale**2, H, W)``.
            scale: Upsampling factor; defaults to ``self.scale``.

        Returns:
            Tensor of shape ``(B, C, scale*H, scale*W)``.
        """
        if scale is None:
            scale = self.scale
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        # Map pixel-centre coordinates to the [-1, 1] range grid_sample expects.
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), scale).view(
            B, 2, -1, scale * H, scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, x.size(-2), x.size(-1)), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, scale * H, scale * W)

    def forward(self, hr_x, lr_x, feat2sample):
        """Predict offsets from ``hr_x`` / ``lr_x`` and resample ``feat2sample``."""
        hr_x = self.norm_hr(hr_x)
        lr_x = self.norm_lr(lr_x)

        if self.direction_feat == 'sim':
            hr_sim = compute_similarity(hr_x, self.local_window, dilation=2, sim='cos')
            lr_sim = compute_similarity(lr_x, self.local_window, dilation=2, sim='cos')
        elif self.direction_feat == 'sim_concat':
            hr_sim = torch.cat([hr_x, compute_similarity(hr_x, self.local_window, dilation=2, sim='cos')], dim=1)
            lr_sim = torch.cat([lr_x, compute_similarity(lr_x, self.local_window, dilation=2, sim='cos')], dim=1)
            hr_x, lr_x = hr_sim, lr_sim
        offset = self.get_offset_lp(hr_x, lr_x, hr_sim, lr_sim)
        return self.sample(feat2sample, offset)

    def get_offset_lp(self, hr_x, lr_x, hr_sim=None, lr_sim=None):
        """Compute sampling offsets at low resolution.

        Args:
            hr_x, lr_x: Inputs of the gate convolutions.
            hr_sim, lr_sim: Inputs of the offset convolutions. They default to
                ``hr_x`` / ``lr_x`` so the two-argument call made by
                ``get_offset`` is valid.

        Returns:
            Offsets of shape ``(B, 2*groups*scale**2, H_lr, W_lr)``.
        """
        if hr_sim is None:
            hr_sim = hr_x
        if lr_sim is None:
            lr_sim = lr_x
        # The offset convolutions are sized for the similarity-based input in
        # both branches; only the gate differs.
        raw_offset = self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)
        if hasattr(self, 'direct_scale'):
            gate = (self.direct_scale(lr_x)
                    + F.pixel_unshuffle(self.hr_direct_scale(hr_x), self.scale)).sigmoid()
            offset = raw_offset * gate + self.init_pos
        else:
            offset = raw_offset * 0.25 + self.init_pos
        return offset

    def get_offset(self, hr_x, lr_x):
        """Two-argument entry point; the same tensors feed offsets and gates."""
        if self.style == 'pl':
            raise NotImplementedError
        return self.get_offset_lp(hr_x, lr_x)


def compute_similarity(input_tensor, k=3, dilation=1, sim='cos'):
    """Similarity between every pixel and the other pixels of its k x k window.

    Args:
        input_tensor: Tensor of shape ``[B, C, H, W]``.
        k: Window size.
        dilation: Dilation of the window.
        sim: ``'cos'`` for cosine similarity, ``'dot'`` for the dot product.

    Returns:
        Tensor of shape ``[B, k*k - 1, H, W]`` (the centre pixel is dropped).
    """
    B, C, H, W = input_tensor.shape
    # Gather each pixel's (zero-padded) neighbourhood: B, C*k*k, H*W.
    unfold_tensor = F.unfold(input_tensor, k, padding=(k // 2) * dilation, dilation=dilation)
    unfold_tensor = unfold_tensor.reshape(B, C, k ** 2, H, W)

    center = k * k // 2
    if sim == 'cos':
        similarity = F.cosine_similarity(unfold_tensor[:, :, center:center + 1], unfold_tensor[:, :, :], dim=1)
    elif sim == 'dot':
        similarity = unfold_tensor[:, :, center:center + 1] * unfold_tensor[:, :, :]
        similarity = similarity.sum(dim=1)
    else:
        raise NotImplementedError

    # Drop the centre pixel's similarity with itself.
    similarity = torch.cat((similarity[:, :center], similarity[:, center + 1:]), dim=1)
    similarity = similarity.view(B, k * k - 1, H, W)
    return similarity


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The pure-PyTorch carafe fallback materialises every neighbourhood, so
    # keep the demo small when no GPU is available.
    hr_size = 512 if device.type == 'cuda' else 64
    x = torch.randn(4, 128, hr_size, hr_size, device=device)
    y = torch.randn(4, 128, hr_size // 2, hr_size // 2, device=device)
    model = FreqFusion(128, 128).to(device)
    out, _, _ = model(x, y)
    print(out.shape)