uisrnn.loss_func

Loss functions for training.

 1# Copyright 2018 Google LLC
 2#
 3# Licensed under the Apache License, Version 2.0 (the "License");
 4# you may not use this file except in compliance with the License.
 5# You may obtain a copy of the License at
 6#
 7#     https://www.apache.org/licenses/LICENSE-2.0
 8#
 9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Loss functions for training."""
15
16import torch
17
18
def weighted_mse_loss(input_tensor, target_tensor, weight=1):
  """Compute weighted MSE loss.

  The squared error is reshaped into rows of `observation_dim` entries (the
  size of the last dimension of `input_tensor`), every column is scaled by its
  weight, and the total is divided by the number of rows whose first
  squared-error entry is non-zero. In other words, this is a weighted loss
  that is normalized only by the non-zero entries.

  Args:
    input_tensor: input tensor
    target_tensor: target tensor
    weight: weight tensor, in this case 1/sigma^2. Either one weight per
        observation dimension, or a single value (a Python number or a
        one-element tensor) that is applied to every dimension.

  Returns:
    the weighted MSE loss, as a scalar tensor. When no row has a non-zero
    first entry the normalizer is zero and the result is nan or inf.
  """
  observation_dim = input_tensor.size()[-1]
  squared_error = ((input_tensor - target_tensor) ** 2).view(
      -1, observation_dim)
  # Only rows whose first squared-error entry is non-zero count toward the
  # normalizer.
  non_zero_entry_num = torch.sum(squared_error[:, 0] != 0).float()
  # as_tensor leaves an existing tensor (and its autograd history) untouched
  # while also accepting plain Python numbers such as the default weight=1.
  weight_vector = torch.as_tensor(
      weight, device=squared_error.device).float().view(-1)
  # Broadcasting over the last dimension scales column j by weight j, which is
  # what multiplying by diag(weight) does, without building the matrix.
  weighted_tensor = squared_error * weight_vector
  return torch.sum(weighted_tensor) / non_zero_entry_num
42
43
def sigma2_prior_loss(num_non_zero, sigma_alpha, sigma_beta, sigma2):
  """Compute sigma2 prior loss.

  The loss is the sum of two terms, each summed over the entries of `sigma2`:
  a log term `(2 * sigma_alpha + num_non_zero + 2) / (2 * num_non_zero) *
  log(sigma2)` and an inverse term `sigma_beta / (sigma2 * num_non_zero)`.

  Args:
    num_non_zero: since rnn_truth is a collection of different length sequences
        padded with zeros to fit them into a tensor, we count the sum of
        'real lengths' of all sequences
    sigma_alpha: inverse gamma shape
    sigma_beta: inverse gamma scale
    sigma2: sigma squared

  Returns:
    the sigma2 prior loss
  """
  log_coefficient = (2 * sigma_alpha + num_non_zero + 2) / (2 * num_non_zero)
  log_term = (log_coefficient * torch.log(sigma2)).sum()
  inverse_term = (sigma_beta / (sigma2 * num_non_zero)).sum()
  return log_term + inverse_term
61
62
def regularization_loss(params, weight):
  """Compute regularization loss.

  Adds up `torch.norm` of every parameter and scales the total by `weight`.

  Args:
    params: iterable of all parameters
    weight: weight for the regularization term

  Returns:
    the regularization loss
  """
  accumulated_norm = 0
  for tensor in params:
    accumulated_norm += torch.norm(tensor)
  penalty = weight * accumulated_norm
  return penalty
def weighted_mse_loss(input_tensor, target_tensor, weight=1):
def weighted_mse_loss(input_tensor, target_tensor, weight=1):
  """Compute weighted MSE loss.

  The squared error is reshaped into rows of `observation_dim` entries (the
  size of the last dimension of `input_tensor`), every column is scaled by its
  weight, and the total is divided by the number of rows whose first
  squared-error entry is non-zero. In other words, this is a weighted loss
  that is normalized only by the non-zero entries.

  Args:
    input_tensor: input tensor
    target_tensor: target tensor
    weight: weight tensor, in this case 1/sigma^2. Either one weight per
        observation dimension, or a single value (a Python number or a
        one-element tensor) that is applied to every dimension.

  Returns:
    the weighted MSE loss, as a scalar tensor. When no row has a non-zero
    first entry the normalizer is zero and the result is nan or inf.
  """
  observation_dim = input_tensor.size()[-1]
  squared_error = ((input_tensor - target_tensor) ** 2).view(
      -1, observation_dim)
  # Only rows whose first squared-error entry is non-zero count toward the
  # normalizer.
  non_zero_entry_num = torch.sum(squared_error[:, 0] != 0).float()
  # as_tensor leaves an existing tensor (and its autograd history) untouched
  # while also accepting plain Python numbers such as the default weight=1.
  weight_vector = torch.as_tensor(
      weight, device=squared_error.device).float().view(-1)
  # Broadcasting over the last dimension scales column j by weight j, which is
  # what multiplying by diag(weight) does, without building the matrix.
  weighted_tensor = squared_error * weight_vector
  return torch.sum(weighted_tensor) / non_zero_entry_num

Compute weighted MSE loss.

Note that this is a weighted loss that only sums over non-zero entries.

Args:

- input_tensor: input tensor
- target_tensor: target tensor
- weight: weight tensor, in this case 1/sigma^2

Returns: the weighted MSE loss

def sigma2_prior_loss(num_non_zero, sigma_alpha, sigma_beta, sigma2):
def sigma2_prior_loss(num_non_zero, sigma_alpha, sigma_beta, sigma2):
  """Compute sigma2 prior loss.

  The loss is the sum of two terms, each summed over the entries of `sigma2`:
  a log term `(2 * sigma_alpha + num_non_zero + 2) / (2 * num_non_zero) *
  log(sigma2)` and an inverse term `sigma_beta / (sigma2 * num_non_zero)`.

  Args:
    num_non_zero: since rnn_truth is a collection of different length sequences
        padded with zeros to fit them into a tensor, we count the sum of
        'real lengths' of all sequences
    sigma_alpha: inverse gamma shape
    sigma_beta: inverse gamma scale
    sigma2: sigma squared

  Returns:
    the sigma2 prior loss
  """
  log_coefficient = (2 * sigma_alpha + num_non_zero + 2) / (2 * num_non_zero)
  log_term = (log_coefficient * torch.log(sigma2)).sum()
  inverse_term = (sigma_beta / (sigma2 * num_non_zero)).sum()
  return log_term + inverse_term

Compute sigma2 prior loss.

Args:

- num_non_zero: since rnn_truth is a collection of different length sequences padded with zeros to fit them into a tensor, we count the sum of 'real lengths' of all sequences
- sigma_alpha: inverse gamma shape
- sigma_beta: inverse gamma scale
- sigma2: sigma squared

Returns: the sigma2 prior loss

def regularization_loss(params, weight):
def regularization_loss(params, weight):
  """Compute regularization loss.

  Adds up `torch.norm` of every parameter and scales the total by `weight`.

  Args:
    params: iterable of all parameters
    weight: weight for the regularization term

  Returns:
    the regularization loss
  """
  accumulated_norm = 0
  for tensor in params:
    accumulated_norm += torch.norm(tensor)
  penalty = weight * accumulated_norm
  return penalty

Compute regularization loss.

Args:

- params: iterable of all parameters
- weight: weight for the regularization term

Returns: the regularization loss