Source code for decode.neuralfitter.utils.last_layer_dynamics

from typing import Tuple

import torch


def weight_by_gradient(layer: torch.nn.ModuleList, loss: torch.Tensor,
                       optimizer: torch.optim.Optimizer) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Weights each output channel by the inverse of its gradient magnitude at the last layer.

    Args:
        layer: module layers (one head per output channel)
        loss: not reduced loss values (N x C x H x W)
        optimizer: optimizer

    Returns:
        weight_cX_h1_w1: weight per channel (1 x C x 1 x 1)
        loss_ch: channel-wise loss
        loss_wch: weighted loss
    """

    """
    Reduce NCHW channel-wise. Dividing by numel and multiplying by ch_out is not
    strictly needed inside this method, but without it the values of loss_wch and
    loss_ch would be off by a constant factor if used directly.
    """
    ch_out = len(layer)
    loss_ch = loss.sum(-1).sum(-1).sum(0) / loss.numel() * ch_out

    head_grads = torch.zeros((ch_out,)).to(loss.device)
    weighting = torch.ones_like(head_grads).to(loss.device)

    # gradient magnitude of each channel's loss w.r.t. its head's output convolution
    for i in range(ch_out):
        head_grads[i] = torch.autograd.grad(loss_ch[i], layer[i].out_conv.weight,
                                            retain_graph=True)[0].abs().sum()

    """Kill the channels which are completely inactive."""
    ix_on = head_grads != 0.
    weighting[~ix_on] = 0.  # set excluded channels to zero

    optimizer.zero_grad()

    # normalise the inverse gradient magnitudes so that the active weights sum to one
    N = (1 / head_grads[ix_on]).sum()
    weighting[ix_on] = weighting[ix_on] / head_grads[ix_on]
    weighting = weighting / N

    loss_wch = (loss_ch * weighting).sum()

    weight_cX_h1_w1 = weighting.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # weight tensor of size 1 x C x 1 x 1

    return weight_cX_h1_w1, loss_ch, loss_wch
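
A minimal usage sketch, not part of the DECODE codebase: the toy backbone, the _ToyHead class, the tensor shapes and the MSE loss below are illustrative assumptions. The only structural requirements weight_by_gradient places on its inputs are that loss is an unreduced N x C x H x W tensor and that each element of layer exposes an out_conv whose weight the corresponding channel loss depends on; the returned 1 x C x 1 x 1 weight then broadcasts over the per-pixel loss.

import torch.nn.functional as F


class _ToyHead(torch.nn.Module):
    """Hypothetical single-channel head exposing an out_conv attribute."""

    def __init__(self, in_ch: int = 16):
        super().__init__()
        self.out_conv = torch.nn.Conv2d(in_ch, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_conv(x)


backbone = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
heads = torch.nn.ModuleList([_ToyHead() for _ in range(5)])
optimizer = torch.optim.Adam([*backbone.parameters(), *heads.parameters()])

x = torch.rand(4, 3, 32, 32)
target = torch.rand(4, 5, 32, 32)

features = backbone(x)
out = torch.cat([h(features) for h in heads], dim=1)   # N x C x H x W
loss = F.mse_loss(out, target, reduction='none')       # not reduced

weight, loss_ch, loss_wch = weight_by_gradient(heads, loss, optimizer)

# broadcast the 1 x C x 1 x 1 weights over the per-pixel loss and step
(loss * weight).mean().backward()
optimizer.step()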