uisrnn.loss_func
Loss functions for training.
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions for training."""

import torch


def weighted_mse_loss(input_tensor, target_tensor, weight=1):
  """Compute weighted MSE loss.

  The squared error is flattened to rows of size `observation_dim` (the last
  dimension of `input_tensor`), each column is scaled by its weight, and the
  total is divided by the number of non-zero rows. A row counts as non-zero
  when the first component of its squared error is non-zero.

  Args:
    input_tensor: input tensor
    target_tensor: target tensor
    weight: weight tensor, in this case 1/sigma^2. Either one weight per
      observation dimension, or a scalar (Python number or single-element
      tensor) that is applied to every observation dimension.

  Returns:
    the weighted MSE loss; 0 when the squared error is zero everywhere
  """
  observation_dim = input_tensor.size()[-1]
  squared_error = ((input_tensor - target_tensor) ** 2).view(
      -1, observation_dim)
  entry_num = float(squared_error.size()[0])
  non_zero_entry_num = torch.sum(squared_error[:, 0] != 0).float()
  # Avoid 0/0 = NaN when no row is non-zero.
  non_zero_entry_num = torch.clamp(non_zero_entry_num, min=1)

  # The default `weight=1` (or any plain number) has no `.float()`, so convert
  # it to a tensor first.
  if not isinstance(weight, torch.Tensor):
    weight = torch.as_tensor(weight, device=squared_error.device)
  weight = weight.float().view(-1)
  if weight.nelement() == 1:
    # A scalar weight applies uniformly to every observation dimension.
    weight = weight.repeat(observation_dim)

  # Multiplying by diag(weight) scales column d of the squared error by
  # weight[d].
  weighted_tensor = torch.mm(squared_error, torch.diag(weight))
  return torch.mean(
      weighted_tensor) * weight.nelement() * entry_num / non_zero_entry_num


def sigma2_prior_loss(num_non_zero, sigma_alpha, sigma_beta, sigma2):
  """Compute sigma2 prior loss.

  Args:
    num_non_zero: since rnn_truth is a collection of different length sequences
      padded with zeros to fit them into a tensor, we count the sum of
      'real lengths' of all sequences
    sigma_alpha: inverse gamma shape
    sigma_beta: inverse gamma scale
    sigma2: sigma squared

  Returns:
    the sigma2 prior loss
  """
  return ((2 * sigma_alpha + num_non_zero + 2) /
          (2 * num_non_zero) * torch.log(sigma2)).sum() + (
              sigma_beta / (sigma2 * num_non_zero)).sum()


def regularization_loss(params, weight):
  """Compute regularization loss.

  Args:
    params: iterable of all parameters
    weight: weight for the regularization term

  Returns:
    the regularization loss: `weight` times the sum of the norms of all
    parameters
  """
  l2_reg = 0
  for param in params:
    l2_reg += torch.norm(param)
  return weight * l2_reg
def weighted_mse_loss(input_tensor, target_tensor, weight=1):
  """Compute weighted MSE loss.

  The squared error is flattened to rows of size `observation_dim` (the last
  dimension of `input_tensor`), each column is scaled by its weight, and the
  total is divided by the number of non-zero rows. A row counts as non-zero
  when the first component of its squared error is non-zero.

  Args:
    input_tensor: input tensor
    target_tensor: target tensor
    weight: weight tensor, in this case 1/sigma^2. Either one weight per
      observation dimension, or a scalar (Python number or single-element
      tensor) that is applied to every observation dimension.

  Returns:
    the weighted MSE loss; 0 when the squared error is zero everywhere
  """
  observation_dim = input_tensor.size()[-1]
  squared_error = ((input_tensor - target_tensor) ** 2).view(
      -1, observation_dim)
  entry_num = float(squared_error.size()[0])
  non_zero_entry_num = torch.sum(squared_error[:, 0] != 0).float()
  # Avoid 0/0 = NaN when no row is non-zero.
  non_zero_entry_num = torch.clamp(non_zero_entry_num, min=1)

  # The default `weight=1` (or any plain number) has no `.float()`, so convert
  # it to a tensor first.
  if not isinstance(weight, torch.Tensor):
    weight = torch.as_tensor(weight, device=squared_error.device)
  weight = weight.float().view(-1)
  if weight.nelement() == 1:
    # A scalar weight applies uniformly to every observation dimension.
    weight = weight.repeat(observation_dim)

  # Multiplying by diag(weight) scales column d of the squared error by
  # weight[d].
  weighted_tensor = torch.mm(squared_error, torch.diag(weight))
  return torch.mean(
      weighted_tensor) * weight.nelement() * entry_num / non_zero_entry_num
Compute weighted MSE loss.
Note that this is a weighted loss that sums only over the non-zero entries.
Args: input_tensor: input tensor target_tensor: target tensor weight: weight tensor, in this case 1/sigma^2
Returns: the weighted MSE loss
def sigma2_prior_loss(num_non_zero, sigma_alpha, sigma_beta, sigma2):
  """Compute sigma2 prior loss.

  The result is the sum of two terms, each summed over all elements of
  `sigma2`:
    (2 * sigma_alpha + num_non_zero + 2) / (2 * num_non_zero) * log(sigma2)
    sigma_beta / (sigma2 * num_non_zero)

  Args:
    num_non_zero: since rnn_truth is a collection of different length sequences
      padded with zeros to fit them into a tensor, we count the sum of
      'real lengths' of all sequences
    sigma_alpha: inverse gamma shape
    sigma_beta: inverse gamma scale
    sigma2: sigma squared

  Returns:
    the sigma2 prior loss
  """
  log_coefficient = (2 * sigma_alpha + num_non_zero + 2) / (2 * num_non_zero)
  log_term = torch.sum(log_coefficient * torch.log(sigma2))
  inverse_term = torch.sum(sigma_beta / (sigma2 * num_non_zero))
  return log_term + inverse_term
Compute sigma2 prior loss.
Args: num_non_zero: since rnn_truth is a collection of different length sequences padded with zeros to fit them into a tensor, we count the sum of 'real lengths' of all sequences sigma_alpha: inverse gamma shape sigma_beta: inverse gamma scale sigma2: sigma squared
Returns: the sigma2 prior loss
def regularization_loss(params, weight):
  """Compute regularization loss.

  Adds up the norm (`torch.norm`) of every parameter and scales the total by
  `weight`.

  Args:
    params: iterable of all parameters
    weight: weight for the regularization term

  Returns:
    the regularization loss
  """
  norm_total = 0
  for tensor in params:
    norm_total += torch.norm(tensor)
  penalty = weight * norm_total
  return penalty
Compute regularization loss.
Args: params: iterable of all parameters weight: weight for the regularization term
Returns: the regularization loss