uisrnn

The module for Unbounded Interleaved-State Recurrent Neural Network.

An introduction is available at [README.md](https://github.com/google/uis-rnn/blob/master/README.md).

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The module for Unbounded Interleaved-State Recurrent Neural Network.

An introduction is available at [README.md].

[README.md]: https://github.com/google/uis-rnn/blob/master/README.md
"""

from . import arguments
from . import evals
from . import uisrnn
from . import utils

parse_arguments = arguments.parse_arguments
compute_sequence_match_accuracy = evals.compute_sequence_match_accuracy
output_result = utils.output_result
UISRNN = uisrnn.UISRNN
parallel_predict = uisrnn.parallel_predict
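
The aliases above form the public API. As a quick orientation, a typical end-to-end workflow looks like the sketch below; `train_sequences`, `train_cluster_ids`, and `test_sequences` are placeholders for your own data (see `fit()` and `predict()` below for their exact formats):

```
import uisrnn

# Parse the three argument namespaces from the command line.
model_args, training_args, inference_args = uisrnn.parse_arguments()

# Construct, train, and run inference.
model = uisrnn.UISRNN(model_args)
model.fit(train_sequences, train_cluster_ids, training_args)
predicted_cluster_ids = model.predict(test_sequences, inference_args)
```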
def parse_arguments():
def parse_arguments():
  """Parse arguments.

  Returns:
    A tuple of:

      - `model_args`: model arguments
      - `training_args`: training arguments
      - `inference_args`: inference arguments
  """
  # model configurations
  model_parser = argparse.ArgumentParser(
      description='Model configurations.', add_help=False)

  model_parser.add_argument(
      '--observation_dim',
      default=_DEFAULT_OBSERVATION_DIM,
      type=int,
      help='The dimension of the embeddings (e.g. d-vectors).')

  model_parser.add_argument(
      '--rnn_hidden_size',
      default=512,
      type=int,
      help='The number of nodes for each RNN layer.')
  model_parser.add_argument(
      '--rnn_depth',
      default=1,
      type=int,
      help='The number of RNN layers.')
  model_parser.add_argument(
      '--rnn_dropout',
      default=0.2,
      type=float,
      help='The dropout rate for all RNN layers.')
  model_parser.add_argument(
      '--transition_bias',
      default=None,
      type=float,
      help='The value of p0, corresponding to Eq. (6) in the '
           'paper. If the value is given, we will fix it to this value. If '
           'the value is None, we will estimate it from training data '
           'using Eq. (13) in the paper.')
  model_parser.add_argument(
      '--crp_alpha',
      default=1.0,
      type=float,
      help='The value of alpha for the Chinese restaurant process (CRP), '
           'corresponding to Eq. (7) in the paper. In this open source '
           'implementation, currently we only support using a given value '
           'of crp_alpha.')
  model_parser.add_argument(
      '--sigma2',
      default=None,
      type=float,
      help='The value of sigma squared, corresponding to Eq. (11) in the '
           'paper. If the value is given, we will fix it to this value. If '
           'the value is None, we will estimate it from training data.')
  model_parser.add_argument(
      '--verbosity',
      default=3,
      type=int,
      help='How verbose the logging information will be. A higher value '
      'represents more verbose information. A general guideline: '
      '0 for fatals; 1 for errors; 2 for finishing important steps; '
      '3 for finishing less important steps; 4 or above for debugging '
      'information.')
  model_parser.add_argument(
      '--enable_cuda',
      default=True,
      type=str2bool,
      help='Whether we should use CUDA if it is available. If False, we will '
      'always use CPU.')

  # training configurations
  training_parser = argparse.ArgumentParser(
      description='Training configurations.', add_help=False)

  training_parser.add_argument(
      '--optimizer',
      '-o',
      default='adam',
      choices=['adam'],
      help='The optimizer for training.')
  training_parser.add_argument(
      '--learning_rate',
      '-l',
      default=1e-3,
      type=float,
      help='The learning rate for training.')
  training_parser.add_argument(
      '--train_iteration',
      '-t',
      default=20000,
      type=int,
      help='The total number of training iterations.')
  training_parser.add_argument(
      '--batch_size',
      '-b',
      default=10,
      type=int,
      help='The batch size for training.')
  training_parser.add_argument(
      '--num_permutations',
      default=10,
      type=int,
      help='The number of permutations per utterance sampled in the training '
           'data.')
  training_parser.add_argument(
      '--sigma_alpha',
      default=1.0,
      type=float,
      help='The inverse gamma shape for estimating sigma2. This value is only '
           'meaningful when sigma2 is not given, and estimated from data.')
  training_parser.add_argument(
      '--sigma_beta',
      default=1.0,
      type=float,
      help='The inverse gamma scale for estimating sigma2. This value is only '
           'meaningful when sigma2 is not given, and estimated from data.')
  training_parser.add_argument(
      '--regularization_weight',
      '-r',
      default=1e-5,
      type=float,
      help='The network regularization multiplier.')
  training_parser.add_argument(
      '--grad_max_norm',
      default=5.0,
      type=float,
      help='Max norm of the gradient.')
  training_parser.add_argument(
      '--enforce_cluster_id_uniqueness',
      default=True,
      type=str2bool,
      help='Whether to enforce cluster ID uniqueness across different '
           'training sequences. Only effective when the first input to fit() '
           'is a list of sequences. In general, assume the cluster IDs for two '
           'sequences are [a, b] and [a, c]. If the `a` from the two sequences '
           'are not the same label, then this arg should be True.')

  # inference configurations
  inference_parser = argparse.ArgumentParser(
      description='Inference configurations.', add_help=False)

  inference_parser.add_argument(
      '--beam_size',
      '-s',
      default=10,
      type=int,
      help='The beam search size for inference.')
  inference_parser.add_argument(
      '--look_ahead',
      default=1,
      type=int,
      help='The number of look ahead steps during inference.')
  inference_parser.add_argument(
      '--test_iteration',
      default=2,
      type=int,
      help='During inference, we concatenate this many duplicates of the '
           'test sequence, and run inference on the concatenated sequence. '
           'Then we return the inference results on the last duplicate as the '
           'final prediction for the test sequence.')

  # a super parser for sanity checks
  super_parser = argparse.ArgumentParser(
      parents=[model_parser, training_parser, inference_parser])

  # get arguments
  super_parser.parse_args()
  model_args, _ = model_parser.parse_known_args()
  training_args, _ = training_parser.parse_known_args()
  inference_args, _ = inference_parser.parse_known_args()

  return (model_args, training_args, inference_args)

Parse arguments.

Returns:
  A tuple of:

  - `model_args`: model arguments
  - `training_args`: training arguments
  - `inference_args`: inference arguments
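
Since all three parsers are registered on the command line, any of the flags above can be overridden when launching a script that calls this function. A sketch (the script name is hypothetical):

```
# python3 train_script.py --train_iteration 5000 --learning_rate 1e-4 --beam_size 10
model_args, training_args, inference_args = uisrnn.parse_arguments()
print(training_args.train_iteration)  # 5000
print(inference_args.beam_size)       # 10
```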
def compute_sequence_match_accuracy(sequence1, sequence2):
def compute_sequence_match_accuracy(sequence1, sequence2):
  """Compute the accuracy between two sequences by finding optimal matching.

  Args:
    sequence1: A list of integers or strings.
    sequence2: A list of integers or strings.

  Returns:
    accuracy: sequence matching accuracy as a number in [0.0, 1.0].

  Raises:
    TypeError: If sequence1 or sequence2 is not a list.
    ValueError: If sequence1 and sequence2 are not of the same size.
  """
  if not isinstance(sequence1, list) or not isinstance(sequence2, list):
    raise TypeError('sequence1 and sequence2 must be lists')
  if not sequence1 or len(sequence1) != len(sequence2):
    raise ValueError(
        'sequence1 and sequence2 must have the same non-zero length')
  # get unique ids from sequences
  unique_ids1 = sorted(set(sequence1))
  unique_ids2 = sorted(set(sequence2))
  inverse_index1 = get_list_inverse_index(unique_ids1)
  inverse_index2 = get_list_inverse_index(unique_ids2)
  # get the count matrix
  count_matrix = np.zeros((len(unique_ids1), len(unique_ids2)))
  for item1, item2 in zip(sequence1, sequence2):
    index1 = inverse_index1[item1]
    index2 = inverse_index2[item2]
    count_matrix[index1, index2] += 1.0
  # solve the assignment that maximizes the total matched count
  row_index, col_index = optimize.linear_sum_assignment(-count_matrix)
  optimal_match_count = count_matrix[row_index, col_index].sum()
  accuracy = optimal_match_count / len(sequence1)
  return accuracy

Compute the accuracy between two sequences by finding optimal matching.

Args:
  sequence1: A list of integers or strings.
  sequence2: A list of integers or strings.

Returns:
  accuracy: sequence matching accuracy as a number in [0.0, 1.0].

Raises:
  TypeError: If sequence1 or sequence2 is not a list.
  ValueError: If sequence1 and sequence2 are not of the same size.
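
A small worked example: the optimal matching below maps `0 -> 'a'` (2 positions) and `1 -> 'b'` (1 position), so 3 of 4 positions agree.

```
from uisrnn import compute_sequence_match_accuracy

accuracy = compute_sequence_match_accuracy(
    [0, 0, 1, 1], ['a', 'a', 'a', 'b'])
print(accuracy)  # 0.75
```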

def output_result(model_args, training_args, test_record):
def output_result(model_args, training_args, test_record):
  """Produce a string to summarize the experiment."""
  accuracy_array, _ = zip(*test_record)
  total_accuracy = np.mean(accuracy_array)
  output_string = """
Config:
  sigma_alpha: {}
  sigma_beta: {}
  crp_alpha: {}
  learning rate: {}
  regularization: {}
  batch size: {}

Performance:
  averaged accuracy: {:.6f}
  accuracy numbers for all testing sequences:
  """.strip().format(
      training_args.sigma_alpha,
      training_args.sigma_beta,
      model_args.crp_alpha,
      training_args.learning_rate,
      training_args.regularization_weight,
      training_args.batch_size,
      total_accuracy)
  for accuracy in accuracy_array:
    output_string += '\n    {:.6f}'.format(accuracy)
  output_string += '\n' + '=' * 80 + '\n'
  filename = 'layer_{}_{}_{:.1f}_result.txt'.format(
      model_args.rnn_hidden_size,
      model_args.rnn_depth, model_args.rnn_dropout)
  with open(filename, 'a') as file_object:
    file_object.write(output_string)
  return output_string

Produce a string to summarize the experiment.
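
Since the function unzips the first column of `test_record`, each record is expected to be a pair whose first element is an accuracy; the second element here is arbitrary per-sequence metadata, used only for illustration:

```
test_record = [(0.98, 120), (0.91, 87), (0.95, 240)]
summary = uisrnn.output_result(model_args, training_args, test_record)
# The summary is also appended to a file named
# 'layer_<rnn_hidden_size>_<rnn_depth>_<rnn_dropout>_result.txt'.
```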

class UISRNN:
class UISRNN:
  """Unbounded Interleaved-State Recurrent Neural Networks."""

  def __init__(self, args):
    """Construct the UISRNN object.

    Args:
      args: Model configurations. See `arguments.py` for details.
    """
    self.observation_dim = args.observation_dim
    self.device = torch.device(
        'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu')
    self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size,
                             args.rnn_depth, self.observation_dim,
                             args.rnn_dropout).to(self.device)
    self.rnn_init_hidden = nn.Parameter(
        torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device))
    # booleans indicating which variables are trainable
    self.estimate_sigma2 = (args.sigma2 is None)
    self.estimate_transition_bias = (args.transition_bias is None)
    # initial values of variables
    sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2
    self.sigma2 = nn.Parameter(
        sigma2 * torch.ones(self.observation_dim).to(self.device))
    self.transition_bias = args.transition_bias
    self.transition_bias_denominator = 0.0
    self.crp_alpha = args.crp_alpha
    self.logger = colortimelog.Logger(args.verbosity)

  def _get_optimizer(self, optimizer, learning_rate):
    """Get optimizer for UISRNN.

    Args:
      optimizer: string - name of the optimizer.
      learning_rate: learning rate for the entire model.
        We do not customize learning rate for separate parts.

    Returns:
      a PyTorch `optim` object
    """
    params = [
        {
            'params': self.rnn_model.parameters()
        },  # rnn parameters
        {
            'params': self.rnn_init_hidden
        }  # rnn initial hidden state
    ]
    if self.estimate_sigma2:  # train sigma2
      params.append({
          'params': self.sigma2
      })  # variance parameters
    assert optimizer == 'adam', 'Only adam optimizer is supported.'
    return optim.Adam(params, lr=learning_rate)

  def save(self, filepath):
    """Save the model to a file.

    Args:
      filepath: the path of the file.
    """
    torch.save({
        'rnn_state_dict': self.rnn_model.state_dict(),
        'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(),
        'transition_bias': self.transition_bias,
        'transition_bias_denominator': self.transition_bias_denominator,
        'crp_alpha': self.crp_alpha,
        'sigma2': self.sigma2.detach().cpu().numpy()}, filepath)

  def load(self, filepath):
    """Load the model from a file.

    Args:
      filepath: the path of the file.
    """
    var_dict = torch.load(filepath)
    self.rnn_model.load_state_dict(var_dict['rnn_state_dict'])
    self.rnn_init_hidden = nn.Parameter(
        torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))
    self.transition_bias = float(var_dict['transition_bias'])
    self.transition_bias_denominator = float(
        var_dict['transition_bias_denominator'])
    self.crp_alpha = float(var_dict['crp_alpha'])
    self.sigma2 = nn.Parameter(
        torch.from_numpy(var_dict['sigma2']).to(self.device))

    self.logger.print(
        3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, '
        'rnn_init_hidden={}'.format(
            self.transition_bias, self.crp_alpha, var_dict['sigma2'],
            var_dict['rnn_init_hidden']))

  def fit_concatenated(self, train_sequence, train_cluster_id, args):
    """Fit UISRNN model to concatenated sequence and cluster_id.

    Args:
      train_sequence: the training observation sequence, which is a
        2-dim numpy array of real numbers, of size `N * D`.

        - `N`: summation of lengths of all utterances.
        - `D`: observation dimension.

        For example,
      ```
      train_sequence =
      [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
       [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
       [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
       [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
       [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'
      ```
        Here `N=5`, `D=4`.

        We concatenate all training utterances into this single sequence.
      train_cluster_id: the speaker id sequence, which is a 1-dim list or
        numpy array of strings, of size `N`.
        For example,
      ```
      train_cluster_id =
        ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
      ```
        'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.

        Note that the order of entries within an utterance is preserved,
        and all utterances are simply concatenated together.
      args: Training configurations. See `arguments.py` for details.

    Raises:
      TypeError: If train_sequence or train_cluster_id is of wrong type.
      ValueError: If train_sequence or train_cluster_id has wrong dimension.
    """
    # check type
    if (not isinstance(train_sequence, np.ndarray) or
        train_sequence.dtype != float):
      raise TypeError('train_sequence should be a numpy array of float type.')
    if isinstance(train_cluster_id, list):
      train_cluster_id = np.array(train_cluster_id)
    if (not isinstance(train_cluster_id, np.ndarray) or
        not train_cluster_id.dtype.name.startswith(('str', 'unicode'))):
      raise TypeError('train_cluster_id must be a numpy array of strings.')
    # check dimension
    if train_sequence.ndim != 2:
      raise ValueError('train_sequence must be 2-dim array.')
    if train_cluster_id.ndim != 1:
      raise ValueError('train_cluster_id must be 1-dim array.')
    # check length and size
    train_total_length, observation_dim = train_sequence.shape
    if observation_dim != self.observation_dim:
      raise ValueError('train_sequence does not match the dimension specified '
                       'by args.observation_dim.')
    if train_total_length != len(train_cluster_id):
      raise ValueError('train_sequence length is not equal to '
                       'train_cluster_id length.')

    self.rnn_model.train()
    optimizer = self._get_optimizer(optimizer=args.optimizer,
                                    learning_rate=args.learning_rate)

    sub_sequences, seq_lengths = utils.resize_sequence(
        sequence=train_sequence,
        cluster_id=train_cluster_id,
        num_permutations=args.num_permutations)

    # For batch learning, pack the entire dataset.
    if args.batch_size is None:
      packed_train_sequence, rnn_truth = utils.pack_sequence(
          sub_sequences,
          seq_lengths,
          args.batch_size,
          self.observation_dim,
          self.device)
    train_loss = []
    for num_iter in range(args.train_iteration):
      optimizer.zero_grad()
      # For online learning, pack a subset in each iteration.
      if args.batch_size is not None:
        packed_train_sequence, rnn_truth = utils.pack_sequence(
            sub_sequences,
            seq_lengths,
            args.batch_size,
            self.observation_dim,
            self.device)
      hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1)
      mean, _ = self.rnn_model(packed_train_sequence, hidden)
      # use mean to predict
      mean = torch.cumsum(mean, dim=0)
      mean_size = mean.size()
      mean = torch.mm(
          torch.diag(
              1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)),
          mean.view(mean_size[0], -1))
      mean = mean.view(mean_size)

      # Likelihood part.
      loss1 = loss_func.weighted_mse_loss(
          input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :],
          target_tensor=rnn_truth,
          weight=1 / (2 * self.sigma2))

      # Sigma2 prior part.
      weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth)
                ** 2).view(-1, observation_dim)
      num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze()
      loss2 = loss_func.sigma2_prior_loss(
          num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2)

      # Regularization part.
      loss3 = loss_func.regularization_loss(
          self.rnn_model.parameters(), args.regularization_weight)

      loss = loss1 + loss2 + loss3
      loss.backward()
      nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm)
      optimizer.step()
      # avoid numerical issues
      self.sigma2.data.clamp_(min=1e-6)

      if (np.remainder(num_iter, 10) == 0 or
          num_iter == args.train_iteration - 1):
        self.logger.print(
            2,
            'Iter: {:d}  \t'
            'Training Loss: {:.4f}    \n'
            '    Negative Log Likelihood: {:.4f}\t'
            'Sigma2 Prior: {:.4f}\t'
            'Regularization: {:.4f}'.format(
                num_iter,
                float(loss.data),
                float(loss1.data),
                float(loss2.data),
                float(loss3.data)))
      train_loss.append(float(loss1.data))  # only save the likelihood part
    self.logger.print(
        1, 'Done training with {} iterations'.format(args.train_iteration))

  def fit(self, train_sequences, train_cluster_ids, args):
    """Fit UISRNN model.

    Args:
      train_sequences: Either a list of training sequences, or a single
        concatenated training sequence:

        1. train_sequences is a list, and each element is a 2-dim numpy array
           of real numbers, of size: `length * D`.
           The length varies among different sequences, but the D is the same.
           In speaker diarization, each sequence is the sequence of speaker
           embeddings of one utterance.
        2. train_sequences is a single concatenated sequence, which is a
           2-dim numpy array of real numbers. See `fit_concatenated()`
           for more details.
      train_cluster_ids: Ground truth labels for train_sequences:

        1. if train_sequences is a list, this must also be a list of the same
           size, each element being a 1-dim list or numpy array of strings.
        2. if train_sequences is a single concatenated sequence, this
           must also be the concatenated 1-dim list or numpy array of strings.
      args: Training configurations. See `arguments.py` for details.

    Raises:
      TypeError: If train_sequences or train_cluster_ids is of wrong type.
    """
    if isinstance(train_sequences, np.ndarray):
      # train_sequences is already the concatenated sequence
      if self.estimate_transition_bias:
        # see issue #55: https://github.com/google/uis-rnn/issues/55
        self.logger.print(
            2,
            'Warning: transition_bias cannot be correctly estimated from a '
            'concatenated sequence; train_sequences will be treated as a '
            'single sequence. This can lead to inaccurate estimation of '
            'transition_bias. Please consider estimating transition_bias '
            'before concatenating the sequences and passing it as argument.')
      train_sequences = [train_sequences]
      train_cluster_ids = [train_cluster_ids]
    elif isinstance(train_sequences, list):
      # train_sequences is a list of un-concatenated sequences;
      # we will concatenate it later, after estimating transition_bias
      pass
    else:
      raise TypeError('train_sequences must be a list or numpy.ndarray')

    # estimate transition_bias
    if self.estimate_transition_bias:
      (transition_bias,
       transition_bias_denominator) = utils.estimate_transition_bias(
           train_cluster_ids)
      # set or update transition_bias
      if self.transition_bias is None:
        self.transition_bias = transition_bias
        self.transition_bias_denominator = transition_bias_denominator
      else:
        self.transition_bias = (
            self.transition_bias * self.transition_bias_denominator +
            transition_bias * transition_bias_denominator) / (
                self.transition_bias_denominator + transition_bias_denominator)
        self.transition_bias_denominator += transition_bias_denominator

    # concatenate train_sequences
    (concatenated_train_sequence,
     concatenated_train_cluster_id) = utils.concatenate_training_data(
         train_sequences,
         train_cluster_ids,
         args.enforce_cluster_id_uniqueness,
         True)

    self.fit_concatenated(
        concatenated_train_sequence, concatenated_train_cluster_id, args)

  def _update_beam_state(self, beam_state, look_ahead_seq, cluster_seq):
    """Update a beam state given a look ahead sequence and known cluster
    assignments.

    Args:
      beam_state: A BeamState object.
      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
        look_ahead: number of steps to look ahead in the beam search.
        D: observation dimension.
      cluster_seq: Cluster assignment sequence for look_ahead_seq.

    Returns:
      new_beam_state: An updated BeamState object.
    """

    loss = 0
    new_beam_state = BeamState(beam_state)
    for sub_idx, cluster in enumerate(cluster_seq):
      if cluster > len(new_beam_state.mean_set):  # invalid trace
        new_beam_state.neg_likelihood = float('inf')
        break
      elif cluster < len(new_beam_state.mean_set):  # existing cluster
        last_cluster = new_beam_state.trace[-1]
        loss = loss_func.weighted_mse_loss(
            input_tensor=torch.squeeze(new_beam_state.mean_set[cluster]),
            target_tensor=look_ahead_seq[sub_idx, :],
            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
        if cluster == last_cluster:
          loss -= np.log(1 - self.transition_bias)
        else:
          loss -= np.log(self.transition_bias) + np.log(
              new_beam_state.block_counts[cluster]) - np.log(
                  sum(new_beam_state.block_counts) + self.crp_alpha)
        # update new mean and new hidden
        mean, hidden = self.rnn_model(
            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
            new_beam_state.hidden_set[cluster])
        new_beam_state.mean_set[cluster] = (new_beam_state.mean_set[cluster]*(
            (np.array(new_beam_state.trace) == cluster).sum() -
            1).astype(float) + mean.clone()) / (
                np.array(new_beam_state.trace) == cluster).sum().astype(
                    float)  # use mean to predict
        new_beam_state.hidden_set[cluster] = hidden.clone()
        if cluster != last_cluster:
          new_beam_state.block_counts[cluster] += 1
        new_beam_state.trace.append(cluster)
      else:  # new cluster
        init_input = autograd.Variable(
            torch.zeros(self.observation_dim)
        ).unsqueeze(0).unsqueeze(0).to(self.device)
        mean, hidden = self.rnn_model(init_input,
                                      self.rnn_init_hidden)
        loss = loss_func.weighted_mse_loss(
            input_tensor=torch.squeeze(mean),
            target_tensor=look_ahead_seq[sub_idx, :],
            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
        loss -= np.log(self.transition_bias) + np.log(
            self.crp_alpha) - np.log(
                sum(new_beam_state.block_counts) + self.crp_alpha)
        # update new mean and new hidden
        mean, hidden = self.rnn_model(
            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
            hidden)
        new_beam_state.append(mean, hidden, cluster)
      new_beam_state.neg_likelihood += loss
    return new_beam_state

  def _calculate_score(self, beam_state, look_ahead_seq):
    """Calculate negative log likelihoods for all possible state allocations
       of a look ahead sequence, according to the current beam state.

    Args:
      beam_state: A BeamState object.
      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
        look_ahead: number of steps to look ahead in the beam search.
        D: observation dimension.

    Returns:
      beam_score_set: an array of scores for each possible state allocation.
    """

    look_ahead, _ = look_ahead_seq.shape
    beam_num_clusters = len(beam_state.mean_set)
    beam_score_set = float('inf') * np.ones(
        beam_num_clusters + 1 + np.arange(look_ahead))
    for cluster_seq, _ in np.ndenumerate(beam_score_set):
      updated_beam_state = self._update_beam_state(beam_state,
                                                   look_ahead_seq, cluster_seq)
      beam_score_set[cluster_seq] = updated_beam_state.neg_likelihood
    return beam_score_set

  def predict_single(self, test_sequence, args):
    """Predict labels for a single test sequence using UISRNN model.

    Args:
      test_sequence: the test observation sequence, which is a 2-dim numpy
        array of real numbers, of size `N * D`.

        - `N`: length of one test utterance.
        - `D`: observation dimension.

        For example:
      ```
      test_sequence =
      [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
       [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
       [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
       [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
       [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'
      ```
        Here `N=5`, `D=4`.
      args: Inference configurations. See `arguments.py` for details.

    Returns:
      predicted_cluster_id: predicted speaker id sequence, which is
        an array of integers, of size `N`.
        For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`

    Raises:
      TypeError: If test_sequence is of wrong type.
      ValueError: If test_sequence has wrong dimension.
    """
    # check type
    if (not isinstance(test_sequence, np.ndarray) or
        test_sequence.dtype != float):
      raise TypeError('test_sequence should be a numpy array of float type.')
    # check dimension
    if test_sequence.ndim != 2:
      raise ValueError('test_sequence must be 2-dim array.')
    # check size
    test_sequence_length, observation_dim = test_sequence.shape
    if observation_dim != self.observation_dim:
      raise ValueError('test_sequence does not match the dimension specified '
                       'by args.observation_dim.')

    self.rnn_model.eval()
    test_sequence = np.tile(test_sequence, (args.test_iteration, 1))
    test_sequence = autograd.Variable(
        torch.from_numpy(test_sequence).float()).to(self.device)
    # bookkeeping for beam search
    beam_set = [BeamState()]
    for num_iter in np.arange(0, args.test_iteration * test_sequence_length,
                              args.look_ahead):
      max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set])
      look_ahead_seq = test_sequence[num_iter:  num_iter + args.look_ahead, :]
      look_ahead_seq_length = look_ahead_seq.shape[0]
      score_set = float('inf') * np.ones(
          np.append(
              args.beam_size, max_clusters + 1 + np.arange(
                  look_ahead_seq_length)))
      for beam_rank, beam_state in enumerate(beam_set):
        beam_score_set = self._calculate_score(beam_state, look_ahead_seq)
        score_set[beam_rank, :] = np.pad(
            beam_score_set,
            np.tile([[0, max_clusters - len(beam_state.mean_set)]],
                    (look_ahead_seq_length, 1)), 'constant',
            constant_values=float('inf'))
      # find top scores
      score_ranked = np.sort(score_set, axis=None)
      score_ranked[score_ranked == float('inf')] = 0
      score_ranked = np.trim_zeros(score_ranked)
      idx_ranked = np.argsort(score_set, axis=None)
      updated_beam_set = []
      for new_beam_rank in range(
          np.min((len(score_ranked), args.beam_size))):
        total_idx = np.unravel_index(idx_ranked[new_beam_rank],
                                     score_set.shape)
        prev_beam_rank = total_idx[0].item()
        cluster_seq = total_idx[1:]
        updated_beam_state = self._update_beam_state(
            beam_set[prev_beam_rank], look_ahead_seq, cluster_seq)
        updated_beam_set.append(updated_beam_state)
      beam_set = updated_beam_set
    predicted_cluster_id = beam_set[0].trace[-test_sequence_length:]
    return predicted_cluster_id

  def predict(self, test_sequences, args):
    """Predict labels for a single or many test sequences using UISRNN model.

    Args:
      test_sequences: Either a list of test sequences, or a single test
        sequence. Each test sequence is a 2-dim numpy array
        of real numbers. See `predict_single()` for details.
      args: Inference configurations. See `arguments.py` for details.

    Returns:
      predicted_cluster_ids: Predicted labels for test_sequences.

        1. if test_sequences is a list, predicted_cluster_ids will be a list
           of the same size, where each element is a 1-dim list of strings.
        2. if test_sequences is a single sequence, predicted_cluster_ids will
           be a 1-dim list of strings.

    Raises:
      TypeError: If test_sequences is of wrong type.
    """
    # check type
    if isinstance(test_sequences, np.ndarray):
      return self.predict_single(test_sequences, args)
    if isinstance(test_sequences, list):
      return [self.predict_single(test_sequence, args)
              for test_sequence in test_sequences]
    raise TypeError('test_sequences should be either a list or numpy array.')

Unbounded Interleaved-State Recurrent Neural Networks.

UISRNN(args)

Construct the UISRNN object.

Args:
  args: Model configurations. See `arguments.py` for details.

Instance attributes: `observation_dim`, `device`, `rnn_model`, `rnn_init_hidden`, `estimate_sigma2`, `estimate_transition_bias`, `sigma2`, `transition_bias`, `transition_bias_denominator`, `crp_alpha`, `logger`.
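
Construction sketch, with the model arguments coming from `parse_arguments()`:

```
model_args, training_args, inference_args = uisrnn.parse_arguments()
model_args.observation_dim = 256  # must match your embedding dimension
model = uisrnn.UISRNN(model_args)
```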
def save(self, filepath):

Save the model to a file.

Args:
  filepath: the path of the file.

def load(self, filepath):

Load the model from a file.

Args:
  filepath: the path of the file.
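
A save/load round trip, with an arbitrary checkpoint filename:

```
model.save('uisrnn_model.ckpt')

restored_model = uisrnn.UISRNN(model_args)
restored_model.load('uisrnn_model.ckpt')
```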

def fit_concatenated(self, train_sequence, train_cluster_id, args):

Fit UISRNN model to concatenated sequence and cluster_id.

Args:

  train_sequence: the training observation sequence, which is a 2-dim numpy array of real numbers, of size `N * D`.

  - `N`: summation of lengths of all utterances.
  - `D`: observation dimension.

  For example,

      train_sequence =
      [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
       [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
       [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
       [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
       [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'

  Here `N=5`, `D=4`. We concatenate all training utterances into this single sequence.

  train_cluster_id: the speaker id sequence, which is a 1-dim list or numpy array of strings, of size `N`. For example,

      train_cluster_id =
        ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']

  'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'. Note that the order of entries within an utterance is preserved, and all utterances are simply concatenated together.

  args: Training configurations. See `arguments.py` for details.

Raises:
  TypeError: If train_sequence or train_cluster_id is of wrong type.
  ValueError: If train_sequence or train_cluster_id has wrong dimension.
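
A toy call matching the example above (random data; `D` must equal `model_args.observation_dim`):

```
import numpy as np

train_sequence = np.random.randn(5, model_args.observation_dim)
train_cluster_id = np.array(
    ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0'])
model.fit_concatenated(train_sequence, train_cluster_id, training_args)
```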

def fit(self, train_sequences, train_cluster_ids, args):

Fit UISRNN model.

Args:

  train_sequences: Either a list of training sequences, or a single concatenated training sequence:

  1. train_sequences is a list, and each element is a 2-dim numpy array
     of real numbers, of size: `length * D`.
     The length varies among different sequences, but the D is the same.
     In speaker diarization, each sequence is the sequence of speaker
     embeddings of one utterance.
  2. train_sequences is a single concatenated sequence, which is a
     2-dim numpy array of real numbers. See `fit_concatenated()`
     for more details.

  train_cluster_ids: Ground truth labels for train_sequences:

  1. if train_sequences is a list, this must also be a list of the same
     size, each element being a 1-dim list or numpy array of strings.
  2. if train_sequences is a single concatenated sequence, this
     must also be the concatenated 1-dim list or numpy array of strings.

  args: Training configurations. See `arguments.py` for details.

Raises:
  TypeError: If train_sequences or train_cluster_ids is of wrong type.
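
Typical training call with per-utterance sequences (shapes and labels are illustrative):

```
train_sequences = [
    np.random.randn(100, model_args.observation_dim),  # utterance 1
    np.random.randn(80, model_args.observation_dim),   # utterance 2
]
train_cluster_ids = [
    np.array(['0'] * 60 + ['1'] * 40),  # speaker labels for utterance 1
    np.array(['0'] * 80),               # speaker labels for utterance 2
]
model.fit(train_sequences, train_cluster_ids, training_args)
```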

def predict_single(self, test_sequence, args):

Predict labels for a single test sequence using UISRNN model.

Args:

  test_sequence: the test observation sequence, which is a 2-dim numpy array of real numbers, of size `N * D`.

  - `N`: length of one test utterance.
  - `D`: observation dimension.

  For example:

      test_sequence =
      [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
       [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
       [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
       [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
       [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'

  Here `N=5`, `D=4`.

  args: Inference configurations. See `arguments.py` for details.

Returns:
  predicted_cluster_id: predicted speaker id sequence, which is an array of integers, of size `N`. For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`.

Raises:
  TypeError: If test_sequence is of wrong type.
  ValueError: If test_sequence has wrong dimension.
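
Inference on a single utterance (random data for illustration):

```
test_sequence = np.random.randn(5, model_args.observation_dim)
predicted_cluster_id = model.predict_single(test_sequence, inference_args)
print(predicted_cluster_id)  # e.g. [0, 1, 0, 0, 1]
```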

def predict(self, test_sequences, args):

Predict labels for a single or many test sequences using UISRNN model.

Args:

  test_sequences: Either a list of test sequences, or a single test sequence. Each test sequence is a 2-dim numpy array of real numbers. See `predict_single()` for details.

  args: Inference configurations. See `arguments.py` for details.

Returns:
  predicted_cluster_ids: Predicted labels for test_sequences.

  1. if test_sequences is a list, predicted_cluster_ids will be a list
     of the same size, where each element is a 1-dim list of strings.
  2. if test_sequences is a single sequence, predicted_cluster_ids will
     be a 1-dim list of strings.

Raises:
  TypeError: If test_sequences is of wrong type.
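
Batch inference over several utterances (the test sequences are placeholders for your own data):

```
predicted_cluster_ids = model.predict(
    [test_sequence_1, test_sequence_2], inference_args)
```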

def parallel_predict(model, test_sequences, args, num_processes=4):
def parallel_predict(model, test_sequences, args, num_processes=4):
  """Run prediction in parallel using torch.multiprocessing.

  This is a beta feature. It makes prediction slower on CPU, but it is
  reported to make prediction faster on GPU.

  Args:
    model: instance of UISRNN model
    test_sequences: a list of test sequences. Each test sequence is a
      2-dim numpy array of real numbers. See `predict_single()` for details.
    args: Inference configurations. See `arguments.py` for details.
    num_processes: number of parallel processes.

  Returns:
    a list of the same size as test_sequences, where each element
    is a 1-dim list of strings.

  Raises:
    TypeError: If test_sequences is of wrong type.
  """
  if not isinstance(test_sequences, list):
    raise TypeError('test_sequences must be a list.')
  ctx = multiprocessing.get_context('forkserver')
  model.rnn_model.share_memory()
  pool = ctx.Pool(num_processes)
  results = pool.map(
      functools.partial(model.predict_single, args=args),
      test_sequences)
  pool.close()
  return results

Run prediction in parallel using torch.multiprocessing.

This is a beta feature. It makes prediction slower on CPU, but it is reported to make prediction faster on GPU.

Args:

  model: instance of UISRNN model

  test_sequences: a list of test sequences. Each test sequence is a 2-dim numpy array of real numbers. See `predict_single()` for details.

  args: Inference configurations. See `arguments.py` for details.

  num_processes: number of parallel processes.

Returns:
  a list of the same size as test_sequences, where each element is a 1-dim list of strings.

Raises:
  TypeError: If test_sequences is of wrong type.
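
A sketch of parallel inference; the 'forkserver' start method requires a platform that supports it (e.g. Linux):

```
predicted_cluster_ids = uisrnn.parallel_predict(
    model, test_sequences, inference_args, num_processes=4)
```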