uisrnn.uisrnn

The UIS-RNN model.

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The UIS-RNN model."""

import functools
import numpy as np
import torch
from torch import autograd
from torch import multiprocessing
from torch import nn
from torch import optim
import torch.nn.functional as F

from uisrnn import loss_func
from uisrnn import utils

_INITIAL_SIGMA2_VALUE = 0.1


class CoreRNN(nn.Module):
  """The core Recurrent Neural Network used by UIS-RNN."""

  def __init__(self, input_dim, hidden_size, depth, observation_dim, dropout=0):
    super().__init__()
    self.hidden_size = hidden_size
    if depth >= 2:
      self.gru = nn.GRU(input_dim, hidden_size, depth, dropout=dropout)
    else:
      self.gru = nn.GRU(input_dim, hidden_size, depth)
    self.linear_mean1 = nn.Linear(hidden_size, hidden_size)
    self.linear_mean2 = nn.Linear(hidden_size, observation_dim)

  def forward(self, input_seq, hidden=None):
    """The forward function of the module."""
    output_seq, hidden = self.gru(input_seq, hidden)
    if isinstance(output_seq, torch.nn.utils.rnn.PackedSequence):
      output_seq, _ = torch.nn.utils.rnn.pad_packed_sequence(
          output_seq, batch_first=False)
    mean = self.linear_mean2(F.relu(self.linear_mean1(output_seq)))
    return mean, hidden

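
# --- Illustrative sketch (added for documentation; not part of the original
# module). A minimal forward pass through CoreRNN, assuming a 256-dim
# observation space and the (seq_len, batch, dim) layout that nn.GRU uses
# by default. The function name `_demo_core_rnn` is hypothetical.
def _demo_core_rnn():
  model = CoreRNN(input_dim=256, hidden_size=512, depth=1,
                  observation_dim=256)
  input_seq = torch.randn(10, 4, 256)   # (seq_len, batch, input_dim)
  mean, hidden = model(input_seq)       # hidden defaults to zeros
  assert mean.shape == (10, 4, 256)     # per-step predicted observation mean
  assert hidden.shape == (1, 4, 512)    # (depth, batch, hidden_size)
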

class BeamState:
  """Structure that contains necessary states for beam search."""

  def __init__(self, source=None):
    if not source:
      self.mean_set = []
      self.hidden_set = []
      self.neg_likelihood = 0
      self.trace = []
      self.block_counts = []
    else:
      self.mean_set = source.mean_set.copy()
      self.hidden_set = source.hidden_set.copy()
      self.trace = source.trace.copy()
      self.block_counts = source.block_counts.copy()
      self.neg_likelihood = source.neg_likelihood

  def append(self, mean, hidden, cluster):
    """Append new item to the BeamState."""
    self.mean_set.append(mean.clone())
    self.hidden_set.append(hidden.clone())
    self.block_counts.append(1)
    self.trace.append(cluster)

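
# --- Illustrative sketch (added for documentation; not part of the original
# module). Copy-constructing a BeamState gives each beam hypothesis its own
# lists, so extending one hypothesis never mutates its parent. Tensor shapes
# below are placeholders.
def _demo_beam_state():
  parent = BeamState()
  parent.append(torch.zeros(1, 1, 4), torch.zeros(1, 1, 8), cluster=0)
  child = BeamState(parent)
  child.append(torch.zeros(1, 1, 4), torch.zeros(1, 1, 8), cluster=1)
  assert parent.trace == [0] and child.trace == [0, 1]
  assert parent.block_counts == [1] and child.block_counts == [1, 1]
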

class UISRNN:
  """Unbounded Interleaved-State Recurrent Neural Networks."""

  def __init__(self, args):
    """Construct the UISRNN object.

    Args:
      args: Model configurations. See `arguments.py` for details.
    """
    self.observation_dim = args.observation_dim
    self.device = torch.device(
        'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu')
    self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size,
                             args.rnn_depth, self.observation_dim,
                             args.rnn_dropout).to(self.device)
    self.rnn_init_hidden = nn.Parameter(
        torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device))
    # booleans indicating which variables are trainable
    self.estimate_sigma2 = (args.sigma2 is None)
    self.estimate_transition_bias = (args.transition_bias is None)
    # initial values of variables
    sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2
    self.sigma2 = nn.Parameter(
        sigma2 * torch.ones(self.observation_dim).to(self.device))
    self.transition_bias = args.transition_bias
    self.transition_bias_denominator = 0.0
    self.crp_alpha = args.crp_alpha
    self.logger = utils.Logger(args.verbosity)

  def _get_optimizer(self, optimizer, learning_rate):
    """Get optimizer for UISRNN.

    Args:
      optimizer: string - name of the optimizer.
      learning_rate: learning rate for the entire model.
        We do not customize learning rate for separate parts.

    Returns:
      a pytorch "optim" object
    """
    params = [
        {
            'params': self.rnn_model.parameters()
        },  # rnn parameters
        {
            'params': self.rnn_init_hidden
        }  # rnn initial hidden state
    ]
    if self.estimate_sigma2:  # train sigma2
      params.append({
          'params': self.sigma2
      })  # variance parameters
    assert optimizer == 'adam', 'Only adam optimizer is supported.'
    return optim.Adam(params, lr=learning_rate)

  def save(self, filepath):
    """Save the model to a file.

    Args:
      filepath: the path of the file.
    """
    torch.save({
        'rnn_state_dict': self.rnn_model.state_dict(),
        'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(),
        'transition_bias': self.transition_bias,
        'transition_bias_denominator': self.transition_bias_denominator,
        'crp_alpha': self.crp_alpha,
        'sigma2': self.sigma2.detach().cpu().numpy()}, filepath)

  def load(self, filepath):
    """Load the model from a file.

    Args:
      filepath: the path of the file.
    """
    var_dict = torch.load(filepath)
    self.rnn_model.load_state_dict(var_dict['rnn_state_dict'])
    self.rnn_init_hidden = nn.Parameter(
        torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))
    self.transition_bias = float(var_dict['transition_bias'])
    self.transition_bias_denominator = float(
        var_dict['transition_bias_denominator'])
    self.crp_alpha = float(var_dict['crp_alpha'])
    self.sigma2 = nn.Parameter(
        torch.from_numpy(var_dict['sigma2']).to(self.device))

    self.logger.print(
        3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, '
        'rnn_init_hidden={}'.format(
            self.transition_bias, self.crp_alpha, var_dict['sigma2'],
            var_dict['rnn_init_hidden']))
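
  # Example (illustrative; not part of the original source): save() and
  # load() round-trip the model through a single file, so a freshly
  # constructed model can be restored from disk. The path is a placeholder.
  #
  #   model = UISRNN(args)
  #   model.save('/tmp/uisrnn.model')
  #   restored = UISRNN(args)
  #   restored.load('/tmp/uisrnn.model')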

  def fit_concatenated(self, train_sequence, train_cluster_id, args):
    """Fit UISRNN model to concatenated sequence and cluster_id.

    Args:
      train_sequence: the training observation sequence, which is a
        2-dim numpy array of real numbers, of size `N * D`.

        - `N`: summation of lengths of all utterances.
        - `D`: observation dimension.

        For example,
      ```
      train_sequence =
      [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
       [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
       [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
       [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
       [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'
      ```
        Here `N=5`, `D=4`.

        We concatenate all training utterances into this single sequence.
      train_cluster_id: the speaker id sequence, which is a 1-dim list or
        numpy array of strings, of size `N`.
        For example,
      ```
      train_cluster_id =
        ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
      ```
        'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.

        Note that the order of entries within an utterance is preserved,
        and all utterances are simply concatenated together.
      args: Training configurations. See `arguments.py` for details.

    Raises:
      TypeError: If train_sequence or train_cluster_id is of wrong type.
      ValueError: If train_sequence or train_cluster_id has wrong dimension.
    """
    # check type
    if (not isinstance(train_sequence, np.ndarray) or
        train_sequence.dtype != float):
      raise TypeError('train_sequence should be a numpy array of float type.')
    if isinstance(train_cluster_id, list):
      train_cluster_id = np.array(train_cluster_id)
    if (not isinstance(train_cluster_id, np.ndarray) or
        not train_cluster_id.dtype.name.startswith(('str', 'unicode'))):
      raise TypeError('train_cluster_id must be a numpy array of strings.')
    # check dimension
    if train_sequence.ndim != 2:
      raise ValueError('train_sequence must be 2-dim array.')
    if train_cluster_id.ndim != 1:
      raise ValueError('train_cluster_id must be 1-dim array.')
    # check length and size
    train_total_length, observation_dim = train_sequence.shape
    if observation_dim != self.observation_dim:
      raise ValueError('train_sequence does not match the dimension specified '
                       'by args.observation_dim.')
    if train_total_length != len(train_cluster_id):
      raise ValueError('train_sequence length is not equal to '
                       'train_cluster_id length.')

    self.rnn_model.train()
    optimizer = self._get_optimizer(optimizer=args.optimizer,
                                    learning_rate=args.learning_rate)

    sub_sequences, seq_lengths = utils.resize_sequence(
        sequence=train_sequence,
        cluster_id=train_cluster_id,
        num_permutations=args.num_permutations)

    # For batch learning, pack the entire dataset.
    if args.batch_size is None:
      packed_train_sequence, rnn_truth = utils.pack_sequence(
          sub_sequences,
          seq_lengths,
          args.batch_size,
          self.observation_dim,
          self.device)
    train_loss = []
    for num_iter in range(args.train_iteration):
      optimizer.zero_grad()
      # For online learning, pack a subset in each iteration.
      if args.batch_size is not None:
        packed_train_sequence, rnn_truth = utils.pack_sequence(
            sub_sequences,
            seq_lengths,
            args.batch_size,
            self.observation_dim,
            self.device)
      hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1)
      mean, _ = self.rnn_model(packed_train_sequence, hidden)
      # use mean to predict: replace each step's output with the running
      # average of all outputs up to that step, i.e.
      # mean[t] = (mean[0] + ... + mean[t]) / (t + 1)
      mean = torch.cumsum(mean, dim=0)
      mean_size = mean.size()
      mean = torch.mm(
          torch.diag(
              1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)),
          mean.view(mean_size[0], -1))
      mean = mean.view(mean_size)

      # Likelihood part.
      loss1 = loss_func.weighted_mse_loss(
          input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :],
          target_tensor=rnn_truth,
          weight=1 / (2 * self.sigma2))

      # Sigma2 prior part.
      weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth)
                ** 2).view(-1, observation_dim)
      num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze()
      loss2 = loss_func.sigma2_prior_loss(
          num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2)

      # Regularization part.
      loss3 = loss_func.regularization_loss(
          self.rnn_model.parameters(), args.regularization_weight)

      loss = loss1 + loss2 + loss3
      loss.backward()
      nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm)
      optimizer.step()
      # avoid numerical issues
      self.sigma2.data.clamp_(min=1e-6)

      if (np.remainder(num_iter, 10) == 0 or
          num_iter == args.train_iteration - 1):
        self.logger.print(
            2,
            'Iter: {:d}  \t'
            'Training Loss: {:.4f}    \n'
            '    Negative Log Likelihood: {:.4f}\t'
            'Sigma2 Prior: {:.4f}\t'
            'Regularization: {:.4f}'.format(
                num_iter,
                float(loss.data),
                float(loss1.data),
                float(loss2.data),
                float(loss3.data)))
      train_loss.append(float(loss1.data))  # only save the likelihood part
    self.logger.print(
        1, 'Done training with {} iterations'.format(args.train_iteration))

  def fit(self, train_sequences, train_cluster_ids, args):
    """Fit UISRNN model.

    Args:
      train_sequences: Either a list of training sequences, or a single
        concatenated training sequence:

        1. train_sequences is a list, and each element is a 2-dim numpy array
           of real numbers, of size: `length * D`.
           The length varies among different sequences, but the `D` is the
           same. In speaker diarization, each sequence is the sequence of
           speaker embeddings of one utterance.
        2. train_sequences is a single concatenated sequence, which is a
           2-dim numpy array of real numbers. See `fit_concatenated()`
           for more details.
      train_cluster_ids: Ground truth labels for train_sequences:

        1. if train_sequences is a list, this must also be a list of the same
           size, each element being a 1-dim list or numpy array of strings.
        2. if train_sequences is a single concatenated sequence, this
           must also be the concatenated 1-dim list or numpy array of strings.
      args: Training configurations. See `arguments.py` for details.

    Raises:
      TypeError: If train_sequences or train_cluster_ids is of wrong type.
    """
    if isinstance(train_sequences, np.ndarray):
      # train_sequences is already the concatenated sequence
      if self.estimate_transition_bias:
        # see issue #55: https://github.com/google/uis-rnn/issues/55
        self.logger.print(
            2,
            'Warning: transition_bias cannot be correctly estimated from a '
            'concatenated sequence; train_sequences will be treated as a '
            'single sequence. This can lead to inaccurate estimation of '
            'transition_bias. Please consider estimating transition_bias '
            'before concatenating the sequences, and passing it as an '
            'argument.')
      train_sequences = [train_sequences]
      train_cluster_ids = [train_cluster_ids]
    elif isinstance(train_sequences, list):
      # train_sequences is a list of un-concatenated sequences;
      # we will concatenate them later, after estimating transition_bias
      pass
    else:
      raise TypeError('train_sequences must be a list or numpy.ndarray')

    # estimate transition_bias
    if self.estimate_transition_bias:
      (transition_bias,
       transition_bias_denominator) = utils.estimate_transition_bias(
           train_cluster_ids)
      # set or update transition_bias
      if self.transition_bias is None:
        self.transition_bias = transition_bias
        self.transition_bias_denominator = transition_bias_denominator
      else:
        # combine the old and new estimates as a weighted average,
        # weighted by their denominators
        self.transition_bias = (
            self.transition_bias * self.transition_bias_denominator +
            transition_bias * transition_bias_denominator) / (
                self.transition_bias_denominator + transition_bias_denominator)
        self.transition_bias_denominator += transition_bias_denominator

    # concatenate train_sequences
    (concatenated_train_sequence,
     concatenated_train_cluster_id) = utils.concatenate_training_data(
         train_sequences,
         train_cluster_ids,
         args.enforce_cluster_id_uniqueness,
         True)

    self.fit_concatenated(
        concatenated_train_sequence, concatenated_train_cluster_id, args)

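  # Example (illustrative; not part of the original source): the typical
  # entry point is fit() with one (sequence, cluster_id) pair per utterance.
  # All names and sizes below are placeholders.
  #
  #   train_sequences = [np.random.rand(100, 256), np.random.rand(80, 256)]
  #   train_cluster_ids = [['a_0'] * 50 + ['a_1'] * 50, ['b_0'] * 80]
  #   model.fit(train_sequences, train_cluster_ids, training_args)
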
  def _update_beam_state(self, beam_state, look_ahead_seq, cluster_seq):
    """Update a beam state given a look ahead sequence and known cluster
    assignments.

    Args:
      beam_state: A BeamState object.
      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
        look_ahead: number of steps to look ahead in the beam search.
        D: observation dimension
      cluster_seq: Cluster assignment sequence for look_ahead_seq.

    Returns:
      new_beam_state: An updated BeamState object.
    """

    loss = 0
    new_beam_state = BeamState(beam_state)
    for sub_idx, cluster in enumerate(cluster_seq):
      if cluster > len(new_beam_state.mean_set):  # invalid trace
        new_beam_state.neg_likelihood = float('inf')
        break
      elif cluster < len(new_beam_state.mean_set):  # existing cluster
        last_cluster = new_beam_state.trace[-1]
        loss = loss_func.weighted_mse_loss(
            input_tensor=torch.squeeze(new_beam_state.mean_set[cluster]),
            target_tensor=look_ahead_seq[sub_idx, :],
            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
        if cluster == last_cluster:
          loss -= np.log(1 - self.transition_bias)
        else:
          loss -= np.log(self.transition_bias) + np.log(
              new_beam_state.block_counts[cluster]) - np.log(
                  sum(new_beam_state.block_counts) + self.crp_alpha)
        # update new mean and new hidden
        mean, hidden = self.rnn_model(
            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
            new_beam_state.hidden_set[cluster])
        new_beam_state.mean_set[cluster] = (new_beam_state.mean_set[cluster]*(
            (np.array(new_beam_state.trace) == cluster).sum() -
            1).astype(float) + mean.clone()) / (
                np.array(new_beam_state.trace) == cluster).sum().astype(
                    float)  # use mean to predict
        new_beam_state.hidden_set[cluster] = hidden.clone()
        if cluster != last_cluster:
          new_beam_state.block_counts[cluster] += 1
        new_beam_state.trace.append(cluster)
      else:  # new cluster
        init_input = autograd.Variable(
            torch.zeros(self.observation_dim)
        ).unsqueeze(0).unsqueeze(0).to(self.device)
        mean, hidden = self.rnn_model(init_input,
                                      self.rnn_init_hidden)
        loss = loss_func.weighted_mse_loss(
            input_tensor=torch.squeeze(mean),
            target_tensor=look_ahead_seq[sub_idx, :],
            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
        loss -= np.log(self.transition_bias) + np.log(
            self.crp_alpha) - np.log(
                sum(new_beam_state.block_counts) + self.crp_alpha)
        # update new mean and new hidden
        mean, hidden = self.rnn_model(
            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
            hidden)
        new_beam_state.append(mean, hidden, cluster)
      new_beam_state.neg_likelihood += loss
    return new_beam_state

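  # Worked sketch of the sequence prior used above (illustrative numbers):
  # with transition_bias p = 0.2, crp_alpha = 1.0 and block_counts = [3, 1],
  # switching to the existing cluster 1 contributes
  #     log(p) + log(block_counts[1]) - log(sum(block_counts) + crp_alpha)
  #   = log(0.2) + log(1) - log(5),
  # staying with the current speaker contributes log(1 - p), and opening a
  # new cluster contributes log(p) + log(crp_alpha) - log(5). Each term
  # enters `loss` with a minus sign, so `loss` accumulates a negative log
  # likelihood.
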
  def _calculate_score(self, beam_state, look_ahead_seq):
    """Calculate negative log likelihoods for all possible state allocations
       of a look ahead sequence, according to the current beam state.

    Args:
      beam_state: A BeamState object.
      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
        look_ahead: number of steps to look ahead in the beam search.
        D: observation dimension

    Returns:
      beam_score_set: a numpy array of scores, one for each possible state
        allocation.
    """

    look_ahead, _ = look_ahead_seq.shape
    beam_num_clusters = len(beam_state.mean_set)
    beam_score_set = float('inf') * np.ones(
        beam_num_clusters + 1 + np.arange(look_ahead))
    for cluster_seq, _ in np.ndenumerate(beam_score_set):
      updated_beam_state = self._update_beam_state(beam_state,
                                                   look_ahead_seq, cluster_seq)
      beam_score_set[cluster_seq] = updated_beam_state.neg_likelihood
    return beam_score_set

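  # Shape sketch (illustrative): with 2 existing clusters and look_ahead = 3,
  # beam_score_set has shape (3, 4, 5); at look-ahead step t the admissible
  # cluster ids are 0 .. 2 + t, and choosing the largest id opens a new
  # cluster.
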
  def predict_single(self, test_sequence, args):
    """Predict labels for a single test sequence using UISRNN model.

    Args:
      test_sequence: the test observation sequence, which is a 2-dim numpy
        array of real numbers, of size `N * D`.

        - `N`: length of one test utterance.
        - `D`: observation dimension.

        For example:
      ```
      test_sequence =
      [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
       [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
       [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
       [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
       [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'
      ```
        Here `N=5`, `D=4`.
      args: Inference configurations. See `arguments.py` for details.

    Returns:
      predicted_cluster_id: predicted speaker id sequence, which is
        an array of integers, of size `N`.
        For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`

    Raises:
      TypeError: If test_sequence is of wrong type.
      ValueError: If test_sequence has wrong dimension.
    """
    # check type
    if (not isinstance(test_sequence, np.ndarray) or
        test_sequence.dtype != float):
      raise TypeError('test_sequence should be a numpy array of float type.')
    # check dimension
    if test_sequence.ndim != 2:
      raise ValueError('test_sequence must be 2-dim array.')
    # check size
    test_sequence_length, observation_dim = test_sequence.shape
    if observation_dim != self.observation_dim:
      raise ValueError('test_sequence does not match the dimension specified '
                       'by args.observation_dim.')

    self.rnn_model.eval()
    test_sequence = np.tile(test_sequence, (args.test_iteration, 1))
    test_sequence = autograd.Variable(
        torch.from_numpy(test_sequence).float()).to(self.device)
    # bookkeeping for beam search
    beam_set = [BeamState()]
    for num_iter in np.arange(0, args.test_iteration * test_sequence_length,
                              args.look_ahead):
      max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set])
      look_ahead_seq = test_sequence[num_iter:  num_iter + args.look_ahead, :]
      look_ahead_seq_length = look_ahead_seq.shape[0]
      score_set = float('inf') * np.ones(
          np.append(
              args.beam_size, max_clusters + 1 + np.arange(
                  look_ahead_seq_length)))
      for beam_rank, beam_state in enumerate(beam_set):
        beam_score_set = self._calculate_score(beam_state, look_ahead_seq)
        score_set[beam_rank, :] = np.pad(
            beam_score_set,
            np.tile([[0, max_clusters - len(beam_state.mean_set)]],
                    (look_ahead_seq_length, 1)), 'constant',
            constant_values=float('inf'))
      # find top scores
      score_ranked = np.sort(score_set, axis=None)
      score_ranked[score_ranked == float('inf')] = 0
      score_ranked = np.trim_zeros(score_ranked)
      idx_ranked = np.argsort(score_set, axis=None)
      updated_beam_set = []
      for new_beam_rank in range(
          np.min((len(score_ranked), args.beam_size))):
        total_idx = np.unravel_index(idx_ranked[new_beam_rank],
                                     score_set.shape)
        prev_beam_rank = total_idx[0].item()
        cluster_seq = total_idx[1:]
        updated_beam_state = self._update_beam_state(
            beam_set[prev_beam_rank], look_ahead_seq, cluster_seq)
        updated_beam_set.append(updated_beam_state)
      beam_set = updated_beam_set
    predicted_cluster_id = beam_set[0].trace[-test_sequence_length:]
    return predicted_cluster_id

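  # Example (illustrative; not part of the original source): decoding one
  # utterance of 256-dim embeddings. All names and sizes are placeholders.
  #
  #   test_sequence = np.random.rand(10, 256)
  #   predicted_ids = model.predict_single(test_sequence, inference_args)
  #   # predicted_ids is now a list of ints, e.g. [0, 0, 1, 1, 0, ...]
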
  def predict(self, test_sequences, args):
    """Predict labels for a single or many test sequences using UISRNN model.

    Args:
      test_sequences: Either a list of test sequences, or a single test
        sequence. Each test sequence is a 2-dim numpy array
        of real numbers. See `predict_single()` for details.
      args: Inference configurations. See `arguments.py` for details.

    Returns:
      predicted_cluster_ids: Predicted labels for test_sequences.

        1. if test_sequences is a list, predicted_cluster_ids will be a list
           of the same size, where each element is a 1-dim list of strings.
        2. if test_sequences is a single sequence, predicted_cluster_ids will
           be a 1-dim list of strings.

    Raises:
      TypeError: If test_sequences is of wrong type.
    """
    # check type
    if isinstance(test_sequences, np.ndarray):
      return self.predict_single(test_sequences, args)
    if isinstance(test_sequences, list):
      return [self.predict_single(test_sequence, args)
              for test_sequence in test_sequences]
    raise TypeError('test_sequences should be either a list or numpy array.')


def parallel_predict(model, test_sequences, args, num_processes=4):
  """Run prediction in parallel using torch.multiprocessing.

  This is a beta feature. It makes prediction slower on CPU, but has been
  reported to make prediction faster on GPU.

  Args:
    model: instance of UISRNN model
    test_sequences: a list of test sequences, or a single test
      sequence. Each test sequence is a 2-dim numpy array
      of real numbers. See `predict_single()` for details.
    args: Inference configurations. See `arguments.py` for details.
    num_processes: number of parallel processes.

  Returns:
    a list of the same size as test_sequences, where each element
    is a 1-dim list of strings.

  Raises:
    TypeError: If test_sequences is of wrong type.
  """
  if not isinstance(test_sequences, list):
    raise TypeError('test_sequences must be a list.')
  ctx = multiprocessing.get_context('forkserver')
  model.rnn_model.share_memory()
  pool = ctx.Pool(num_processes)
  results = pool.map(
      functools.partial(model.predict_single, args=args),
      test_sequences)
  pool.close()
  return results
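
Example usage (an illustrative sketch, not part of the module source): a
minimal end-to-end run on random placeholder data, using the package-level
helpers exposed by uis-rnn (`uisrnn.parse_arguments`, `uisrnn.UISRNN`).

    import numpy as np
    import uisrnn

    model_args, training_args, inference_args = uisrnn.parse_arguments()
    model_args.observation_dim = 16
    model = uisrnn.UISRNN(model_args)

    # Two toy utterances of 16-dim "embeddings" with string speaker labels.
    train_sequences = [np.random.rand(20, 16), np.random.rand(30, 16)]
    train_cluster_ids = [
        np.array(['u1_0'] * 10 + ['u1_1'] * 10),
        np.array(['u2_0'] * 30)]
    model.fit(train_sequences, train_cluster_ids, training_args)

    test_sequence = np.random.rand(12, 16)
    predicted_ids = model.predict(test_sequence, inference_args)  # 12 labels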
class CoreRNN(torch.nn.modules.module.Module):
32class CoreRNN(nn.Module):
33  """The core Recurent Neural Network used by UIS-RNN."""
34
35  def __init__(self, input_dim, hidden_size, depth, observation_dim, dropout=0):
36    super().__init__()
37    self.hidden_size = hidden_size
38    if depth >= 2:
39      self.gru = nn.GRU(input_dim, hidden_size, depth, dropout=dropout)
40    else:
41      self.gru = nn.GRU(input_dim, hidden_size, depth)
42    self.linear_mean1 = nn.Linear(hidden_size, hidden_size)
43    self.linear_mean2 = nn.Linear(hidden_size, observation_dim)
44
45  def forward(self, input_seq, hidden=None):
46    """The forward function of the module."""
47    output_seq, hidden = self.gru(input_seq, hidden)
48    if isinstance(output_seq, torch.nn.utils.rnn.PackedSequence):
49      output_seq, _ = torch.nn.utils.rnn.pad_packed_sequence(
50          output_seq, batch_first=False)
51    mean = self.linear_mean2(F.relu(self.linear_mean1(output_seq)))
52    return mean, hidden

The core Recurent Neural Network used by UIS-RNN.

CoreRNN(input_dim, hidden_size, depth, observation_dim, dropout=0)
35  def __init__(self, input_dim, hidden_size, depth, observation_dim, dropout=0):
36    super().__init__()
37    self.hidden_size = hidden_size
38    if depth >= 2:
39      self.gru = nn.GRU(input_dim, hidden_size, depth, dropout=dropout)
40    else:
41      self.gru = nn.GRU(input_dim, hidden_size, depth)
42    self.linear_mean1 = nn.Linear(hidden_size, hidden_size)
43    self.linear_mean2 = nn.Linear(hidden_size, observation_dim)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

hidden_size
linear_mean1
linear_mean2
def forward(self, input_seq, hidden=None):
45  def forward(self, input_seq, hidden=None):
46    """The forward function of the module."""
47    output_seq, hidden = self.gru(input_seq, hidden)
48    if isinstance(output_seq, torch.nn.utils.rnn.PackedSequence):
49      output_seq, _ = torch.nn.utils.rnn.pad_packed_sequence(
50          output_seq, batch_first=False)
51    mean = self.linear_mean2(F.relu(self.linear_mean1(output_seq)))
52    return mean, hidden

The forward function of the module.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
class BeamState:
55class BeamState:
56  """Structure that contains necessary states for beam search."""
57
58  def __init__(self, source=None):
59    if not source:
60      self.mean_set = []
61      self.hidden_set = []
62      self.neg_likelihood = 0
63      self.trace = []
64      self.block_counts = []
65    else:
66      self.mean_set = source.mean_set.copy()
67      self.hidden_set = source.hidden_set.copy()
68      self.trace = source.trace.copy()
69      self.block_counts = source.block_counts.copy()
70      self.neg_likelihood = source.neg_likelihood
71
72  def append(self, mean, hidden, cluster):
73    """Append new item to the BeamState."""
74    self.mean_set.append(mean.clone())
75    self.hidden_set.append(hidden.clone())
76    self.block_counts.append(1)
77    self.trace.append(cluster)

Structure that contains necessary states for beam search.

BeamState(source=None)
58  def __init__(self, source=None):
59    if not source:
60      self.mean_set = []
61      self.hidden_set = []
62      self.neg_likelihood = 0
63      self.trace = []
64      self.block_counts = []
65    else:
66      self.mean_set = source.mean_set.copy()
67      self.hidden_set = source.hidden_set.copy()
68      self.trace = source.trace.copy()
69      self.block_counts = source.block_counts.copy()
70      self.neg_likelihood = source.neg_likelihood
def append(self, mean, hidden, cluster):
72  def append(self, mean, hidden, cluster):
73    """Append new item to the BeamState."""
74    self.mean_set.append(mean.clone())
75    self.hidden_set.append(hidden.clone())
76    self.block_counts.append(1)
77    self.trace.append(cluster)

Append new item to the BeamState.

class UISRNN:
 80class UISRNN:
 81  """Unbounded Interleaved-State Recurrent Neural Networks."""
 82
 83  def __init__(self, args):
 84    """Construct the UISRNN object.
 85
 86    Args:
 87      args: Model configurations. See `arguments.py` for details.
 88    """
 89    self.observation_dim = args.observation_dim
 90    self.device = torch.device(
 91        'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu')
 92    self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size,
 93                             args.rnn_depth, self.observation_dim,
 94                             args.rnn_dropout).to(self.device)
 95    self.rnn_init_hidden = nn.Parameter(
 96        torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device))
 97    # booleans indicating which variables are trainable
 98    self.estimate_sigma2 = (args.sigma2 is None)
 99    self.estimate_transition_bias = (args.transition_bias is None)
100    # initial values of variables
101    sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2
102    self.sigma2 = nn.Parameter(
103        sigma2 * torch.ones(self.observation_dim).to(self.device))
104    self.transition_bias = args.transition_bias
105    self.transition_bias_denominator = 0.0
106    self.crp_alpha = args.crp_alpha
107    self.logger = utils.Logger(args.verbosity)
108
109  def _get_optimizer(self, optimizer, learning_rate):
110    """Get optimizer for UISRNN.
111
112    Args:
113      optimizer: string - name of the optimizer.
114      learning_rate: - learning rate for the entire model.
115        We do not customize learning rate for separate parts.
116
117    Returns:
118      a pytorch "optim" object
119    """
120    params = [
121        {
122            'params': self.rnn_model.parameters()
123        },  # rnn parameters
124        {
125            'params': self.rnn_init_hidden
126        }  # rnn initial hidden state
127    ]
128    if self.estimate_sigma2:  # train sigma2
129      params.append({
130          'params': self.sigma2
131      })  # variance parameters
132    assert optimizer == 'adam', 'Only adam optimizer is supported.'
133    return optim.Adam(params, lr=learning_rate)
134
135  def save(self, filepath):
136    """Save the model to a file.
137
138    Args:
139      filepath: the path of the file.
140    """
141    torch.save({
142        'rnn_state_dict': self.rnn_model.state_dict(),
143        'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(),
144        'transition_bias': self.transition_bias,
145        'transition_bias_denominator': self.transition_bias_denominator,
146        'crp_alpha': self.crp_alpha,
147        'sigma2': self.sigma2.detach().cpu().numpy()}, filepath)
148
149  def load(self, filepath):
150    """Load the model from a file.
151
152    Args:
153      filepath: the path of the file.
154    """
155    var_dict = torch.load(filepath)
156    self.rnn_model.load_state_dict(var_dict['rnn_state_dict'])
157    self.rnn_init_hidden = nn.Parameter(
158        torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))
159    self.transition_bias = float(var_dict['transition_bias'])
160    self.transition_bias_denominator = float(
161        var_dict['transition_bias_denominator'])
162    self.crp_alpha = float(var_dict['crp_alpha'])
163    self.sigma2 = nn.Parameter(
164        torch.from_numpy(var_dict['sigma2']).to(self.device))
165
166    self.logger.print(
167        3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, '
168        'rnn_init_hidden={}'.format(
169            self.transition_bias, self.crp_alpha, var_dict['sigma2'],
170            var_dict['rnn_init_hidden']))
171
172  def fit_concatenated(self, train_sequence, train_cluster_id, args):
173    """Fit UISRNN model to concatenated sequence and cluster_id.
174
175    Args:
176      train_sequence: the training observation sequence, which is a
177        2-dim numpy array of real numbers, of size `N * D`.
178
179        - `N`: summation of lengths of all utterances.
180        - `D`: observation dimension.
181
182        For example,
183      ```
184      train_sequence =
185      [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
186       [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
187       [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
188       [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
189       [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'
190      ```
191        Here `N=5`, `D=4`.
192
193        We concatenate all training utterances into this single sequence.
194      train_cluster_id: the speaker id sequence, which is 1-dim list or
195        numpy array of strings, of size `N`.
196        For example,
197      ```
198      train_cluster_id =
199        ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
200      ```
201        'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.
202
203        Note that the order of entries within an utterance are preserved,
204        and all utterances are simply concatenated together.
205      args: Training configurations. See `arguments.py` for details.
206
207    Raises:
208      TypeError: If train_sequence or train_cluster_id is of wrong type.
209      ValueError: If train_sequence or train_cluster_id has wrong dimension.
210    """
211    # check type
212    if (not isinstance(train_sequence, np.ndarray) or
213        train_sequence.dtype != float):
214      raise TypeError('train_sequence should be a numpy array of float type.')
215    if isinstance(train_cluster_id, list):
216      train_cluster_id = np.array(train_cluster_id)
217    if (not isinstance(train_cluster_id, np.ndarray) or
218        not train_cluster_id.dtype.name.startswith(('str', 'unicode'))):
219      raise TypeError('train_cluster_id type be a numpy array of strings.')
220    # check dimension
221    if train_sequence.ndim != 2:
222      raise ValueError('train_sequence must be 2-dim array.')
223    if train_cluster_id.ndim != 1:
224      raise ValueError('train_cluster_id must be 1-dim array.')
225    # check length and size
226    train_total_length, observation_dim = train_sequence.shape
227    if observation_dim != self.observation_dim:
228      raise ValueError('train_sequence does not match the dimension specified '
229                       'by args.observation_dim.')
230    if train_total_length != len(train_cluster_id):
231      raise ValueError('train_sequence length is not equal to '
232                       'train_cluster_id length.')
233
234    self.rnn_model.train()
235    optimizer = self._get_optimizer(optimizer=args.optimizer,
236                                    learning_rate=args.learning_rate)
237
238    sub_sequences, seq_lengths = utils.resize_sequence(
239        sequence=train_sequence,
240        cluster_id=train_cluster_id,
241        num_permutations=args.num_permutations)
242
243    # For batch learning, pack the entire dataset.
244    if args.batch_size is None:
245      packed_train_sequence, rnn_truth = utils.pack_sequence(
246          sub_sequences,
247          seq_lengths,
248          args.batch_size,
249          self.observation_dim,
250          self.device)
251    train_loss = []
252    for num_iter in range(args.train_iteration):
253      optimizer.zero_grad()
254      # For online learning, pack a subset in each iteration.
255      if args.batch_size is not None:
256        packed_train_sequence, rnn_truth = utils.pack_sequence(
257            sub_sequences,
258            seq_lengths,
259            args.batch_size,
260            self.observation_dim,
261            self.device)
262      hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1)
263      mean, _ = self.rnn_model(packed_train_sequence, hidden)
264      # use mean to predict
265      mean = torch.cumsum(mean, dim=0)
266      mean_size = mean.size()
267      mean = torch.mm(
268          torch.diag(
269              1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)),
270          mean.view(mean_size[0], -1))
271      mean = mean.view(mean_size)
272
273      # Likelihood part.
274      loss1 = loss_func.weighted_mse_loss(
275          input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :],
276          target_tensor=rnn_truth,
277          weight=1 / (2 * self.sigma2))
278
279      # Sigma2 prior part.
280      weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth)
281                ** 2).view(-1, observation_dim)
282      num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze()
283      loss2 = loss_func.sigma2_prior_loss(
284          num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2)
285
286      # Regularization part.
287      loss3 = loss_func.regularization_loss(
288          self.rnn_model.parameters(), args.regularization_weight)
289
290      loss = loss1 + loss2 + loss3
291      loss.backward()
292      nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm)
293      optimizer.step()
294      # avoid numerical issues
295      self.sigma2.data.clamp_(min=1e-6)
296
297      if (np.remainder(num_iter, 10) == 0 or
298          num_iter == args.train_iteration - 1):
299        self.logger.print(
300            2,
301            'Iter: {:d}  \t'
302            'Training Loss: {:.4f}    \n'
303            '    Negative Log Likelihood: {:.4f}\t'
304            'Sigma2 Prior: {:.4f}\t'
305            'Regularization: {:.4f}'.format(
306                num_iter,
307                float(loss.data),
308                float(loss1.data),
309                float(loss2.data),
310                float(loss3.data)))
311      train_loss.append(float(loss1.data))  # only save the likelihood part
312    self.logger.print(
313        1, 'Done training with {} iterations'.format(args.train_iteration))
314
315  def fit(self, train_sequences, train_cluster_ids, args):
316    """Fit UISRNN model.
317
318    Args:
319      train_sequences: Either a list of training sequences, or a single
320        concatenated training sequence:
321
322        1. train_sequences is list, and each element is a 2-dim numpy array
323           of real numbers, of size: `length * D`.
324           The length varies among different sequences, but the D is the same.
325           In speaker diarization, each sequence is the sequence of speaker
326           embeddings of one utterance.
327        2. train_sequences is a single concatenated sequence, which is a
328           2-dim numpy array of real numbers. See `fit_concatenated()`
329           for more details.
330      train_cluster_ids: Ground truth labels for train_sequences:
331
332        1. if train_sequences is a list, this must also be a list of the same
333           size, each element being a 1-dim list or numpy array of strings.
334        2. if train_sequences is a single concatenated sequence, this
335           must also be the concatenated 1-dim list or numpy array of strings
336      args: Training configurations. See `arguments.py` for details.
337
338    Raises:
339      TypeError: If train_sequences or train_cluster_ids is of wrong type.
340    """
341    if isinstance(train_sequences, np.ndarray):
342      # train_sequences is already the concatenated sequence
343      if self.estimate_transition_bias:
344        # see issue #55: https://github.com/google/uis-rnn/issues/55
345        self.logger.print(
346            2,
347            'Warning: transition_bias cannot be correctly estimated from a '
348            'concatenated sequence; train_sequences will be treated as a '
349            'single sequence. This can lead to inaccurate estimation of '
350            'transition_bias. Please, consider estimating transition_bias '
351            'before concatenating the sequences and passing it as argument.')
352      train_sequences = [train_sequences]
353      train_cluster_ids = [train_cluster_ids]
354    elif isinstance(train_sequences, list):
355      # train_sequences is a list of un-concatenated sequences
356      # we will concatenate it later, after estimating transition_bias
357      pass
358    else:
359      raise TypeError('train_sequences must be a list or numpy.ndarray')
360
361    # estimate transition_bias
362    if self.estimate_transition_bias:
363      (transition_bias,
364       transition_bias_denominator) = utils.estimate_transition_bias(
365           train_cluster_ids)
366      # set or update transition_bias
367      if self.transition_bias is None:
368        self.transition_bias = transition_bias
369        self.transition_bias_denominator = transition_bias_denominator
370      else:
371        self.transition_bias = (
372            self.transition_bias * self.transition_bias_denominator +
373            transition_bias * transition_bias_denominator) / (
374                self.transition_bias_denominator + transition_bias_denominator)
375        self.transition_bias_denominator += transition_bias_denominator
376
377    # concatenate train_sequences
378    (concatenated_train_sequence,
379     concatenated_train_cluster_id) = utils.concatenate_training_data(
380         train_sequences,
381         train_cluster_ids,
382         args.enforce_cluster_id_uniqueness,
383         True)
384
385    self.fit_concatenated(
386        concatenated_train_sequence, concatenated_train_cluster_id, args)
387
388  def _update_beam_state(self, beam_state, look_ahead_seq, cluster_seq):
389    """Update a beam state given a look ahead sequence and known cluster
390    assignments.
391
392    Args:
393      beam_state: A BeamState object.
394      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
395        look_ahead: number of step to look ahead in the beam search.
396        D: observation dimension
397      cluster_seq: Cluster assignment sequence for look_ahead_seq.
398
399    Returns:
400      new_beam_state: An updated BeamState object.
401    """
402
403    loss = 0
404    new_beam_state = BeamState(beam_state)
405    for sub_idx, cluster in enumerate(cluster_seq):
406      if cluster > len(new_beam_state.mean_set):  # invalid trace
407        new_beam_state.neg_likelihood = float('inf')
408        break
409      elif cluster < len(new_beam_state.mean_set):  # existing cluster
410        last_cluster = new_beam_state.trace[-1]
411        loss = loss_func.weighted_mse_loss(
412            input_tensor=torch.squeeze(new_beam_state.mean_set[cluster]),
413            target_tensor=look_ahead_seq[sub_idx, :],
414            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
415        if cluster == last_cluster:
416          loss -= np.log(1 - self.transition_bias)
417        else:
418          loss -= np.log(self.transition_bias) + np.log(
419              new_beam_state.block_counts[cluster]) - np.log(
420                  sum(new_beam_state.block_counts) + self.crp_alpha)
421        # update new mean and new hidden
422        mean, hidden = self.rnn_model(
423            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
424            new_beam_state.hidden_set[cluster])
425        new_beam_state.mean_set[cluster] = (new_beam_state.mean_set[cluster]*(
426            (np.array(new_beam_state.trace) == cluster).sum() -
427            1).astype(float) + mean.clone()) / (
428                np.array(new_beam_state.trace) == cluster).sum().astype(
429                    float)  # use mean to predict
430        new_beam_state.hidden_set[cluster] = hidden.clone()
431        if cluster != last_cluster:
432          new_beam_state.block_counts[cluster] += 1
433        new_beam_state.trace.append(cluster)
434      else:  # new cluster
435        init_input = autograd.Variable(
436            torch.zeros(self.observation_dim)
437        ).unsqueeze(0).unsqueeze(0).to(self.device)
438        mean, hidden = self.rnn_model(init_input,
439                                      self.rnn_init_hidden)
440        loss = loss_func.weighted_mse_loss(
441            input_tensor=torch.squeeze(mean),
442            target_tensor=look_ahead_seq[sub_idx, :],
443            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
444        loss -= np.log(self.transition_bias) + np.log(
445            self.crp_alpha) - np.log(
446                sum(new_beam_state.block_counts) + self.crp_alpha)
447        # update new min and new hidden
448        mean, hidden = self.rnn_model(
449            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
450            hidden)
451        new_beam_state.append(mean, hidden, cluster)
452      new_beam_state.neg_likelihood += loss
453    return new_beam_state
454
455  def _calculate_score(self, beam_state, look_ahead_seq):
456    """Calculate negative log likelihoods for all possible state allocations
457       of a look ahead sequence, according to the current beam state.
458
459    Args:
460      beam_state: A BeamState object.
461      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
462        look_ahead: number of step to look ahead in the beam search.
463        D: observation dimension
464
465    Returns:
466      beam_score_set: a set of scores for each possible state allocation.
467    """
468
469    look_ahead, _ = look_ahead_seq.shape
470    beam_num_clusters = len(beam_state.mean_set)
471    beam_score_set = float('inf') * np.ones(
472        beam_num_clusters + 1 + np.arange(look_ahead))
473    for cluster_seq, _ in np.ndenumerate(beam_score_set):
474      updated_beam_state = self._update_beam_state(beam_state,
475                                                   look_ahead_seq, cluster_seq)
476      beam_score_set[cluster_seq] = updated_beam_state.neg_likelihood
477    return beam_score_set
478
479  def predict_single(self, test_sequence, args):
480    """Predict labels for a single test sequence using UISRNN model.
481
482    Args:
483      test_sequence: the test observation sequence, which is 2-dim numpy array
484        of real numbers, of size `N * D`.
485
486        - `N`: length of one test utterance.
487        - `D` : observation dimension.
488
489        For example:
490      ```
491      test_sequence =
492      [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
493       [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
494       [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
495       [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
496       [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'
497      ```
498        Here `N=5`, `D=4`.
499      args: Inference configurations. See `arguments.py` for details.
500
501    Returns:
502      predicted_cluster_id: predicted speaker id sequence, which is
503        an array of integers, of size `N`.
504        For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`
505
506    Raises:
507      TypeError: If test_sequence is of wrong type.
508      ValueError: If test_sequence has wrong dimension.
509    """
510    # check type
511    if (not isinstance(test_sequence, np.ndarray) or
512        test_sequence.dtype != float):
513      raise TypeError('test_sequence should be a numpy array of float type.')
514    # check dimension
515    if test_sequence.ndim != 2:
516      raise ValueError('test_sequence must be 2-dim array.')
517    # check size
518    test_sequence_length, observation_dim = test_sequence.shape
519    if observation_dim != self.observation_dim:
520      raise ValueError('test_sequence does not match the dimension specified '
521                       'by args.observation_dim.')
522
523    self.rnn_model.eval()
524    test_sequence = np.tile(test_sequence, (args.test_iteration, 1))
525    test_sequence = autograd.Variable(
526        torch.from_numpy(test_sequence).float()).to(self.device)
527    # bookkeeping for beam search
528    beam_set = [BeamState()]
529    for num_iter in np.arange(0, args.test_iteration * test_sequence_length,
530                              args.look_ahead):
531      max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set])
532      look_ahead_seq = test_sequence[num_iter:  num_iter + args.look_ahead, :]
533      look_ahead_seq_length = look_ahead_seq.shape[0]
534      score_set = float('inf') * np.ones(
535          np.append(
536              args.beam_size, max_clusters + 1 + np.arange(
537                  look_ahead_seq_length)))
538      for beam_rank, beam_state in enumerate(beam_set):
539        beam_score_set = self._calculate_score(beam_state, look_ahead_seq)
540        score_set[beam_rank, :] = np.pad(
541            beam_score_set,
542            np.tile([[0, max_clusters - len(beam_state.mean_set)]],
543                    (look_ahead_seq_length, 1)), 'constant',
544            constant_values=float('inf'))
545      # find top scores
546      score_ranked = np.sort(score_set, axis=None)
547      score_ranked[score_ranked == float('inf')] = 0
548      score_ranked = np.trim_zeros(score_ranked)
549      idx_ranked = np.argsort(score_set, axis=None)
550      updated_beam_set = []
551      for new_beam_rank in range(
552          np.min((len(score_ranked), args.beam_size))):
553        total_idx = np.unravel_index(idx_ranked[new_beam_rank],
554                                     score_set.shape)
555        prev_beam_rank = total_idx[0].item()
556        cluster_seq = total_idx[1:]
557        updated_beam_state = self._update_beam_state(
558            beam_set[prev_beam_rank], look_ahead_seq, cluster_seq)
559        updated_beam_set.append(updated_beam_state)
560      beam_set = updated_beam_set
561    predicted_cluster_id = beam_set[0].trace[-test_sequence_length:]
562    return predicted_cluster_id
563
564  def predict(self, test_sequences, args):
565    """Predict labels for a single or many test sequences using UISRNN model.
566
567    Args:
568      test_sequences: Either a list of test sequences, or a single test
569        sequence. Each test sequence is a 2-dim numpy array
570        of real numbers. See `predict_single()` for details.
571      args: Inference configurations. See `arguments.py` for details.
572
573    Returns:
574      predicted_cluster_ids: Predicted labels for test_sequences.
575
576        1. if test_sequences is a list, predicted_cluster_ids will be a list
577           of the same size, where each element being a 1-dim list of strings.
578        2. if test_sequences is a single sequence, predicted_cluster_ids will
579           be a 1-dim list of strings
580
581    Raises:
582      TypeError: If test_sequences is of wrong type.
583    """
584    # check type
585    if isinstance(test_sequences, np.ndarray):
586      return self.predict_single(test_sequences, args)
587    if isinstance(test_sequences, list):
588      return [self.predict_single(test_sequence, args)
589              for test_sequence in test_sequences]
590    raise TypeError('test_sequences should be either a list or numpy array.')

Unbounded Interleaved-State Recurrent Neural Networks.

UISRNN(args)
 83  def __init__(self, args):
 84    """Construct the UISRNN object.
 85
 86    Args:
 87      args: Model configurations. See `arguments.py` for details.
 88    """
 89    self.observation_dim = args.observation_dim
 90    self.device = torch.device(
 91        'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu')
 92    self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size,
 93                             args.rnn_depth, self.observation_dim,
 94                             args.rnn_dropout).to(self.device)
 95    self.rnn_init_hidden = nn.Parameter(
 96        torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device))
 97    # booleans indicating which variables are trainable
 98    self.estimate_sigma2 = (args.sigma2 is None)
 99    self.estimate_transition_bias = (args.transition_bias is None)
100    # initial values of variables
101    sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2
102    self.sigma2 = nn.Parameter(
103        sigma2 * torch.ones(self.observation_dim).to(self.device))
104    self.transition_bias = args.transition_bias
105    self.transition_bias_denominator = 0.0
106    self.crp_alpha = args.crp_alpha
107    self.logger = utils.Logger(args.verbosity)

Construct the UISRNN object.

Args: args: Model configurations. See arguments.py for details.

observation_dim
device
rnn_model
rnn_init_hidden
estimate_sigma2
estimate_transition_bias
sigma2
transition_bias
transition_bias_denominator
crp_alpha
logger
def save(self, filepath):
135  def save(self, filepath):
136    """Save the model to a file.
137
138    Args:
139      filepath: the path of the file.
140    """
141    torch.save({
142        'rnn_state_dict': self.rnn_model.state_dict(),
143        'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(),
144        'transition_bias': self.transition_bias,
145        'transition_bias_denominator': self.transition_bias_denominator,
146        'crp_alpha': self.crp_alpha,
147        'sigma2': self.sigma2.detach().cpu().numpy()}, filepath)

Save the model to a file.

Args: filepath: the path of the file.

def load(self, filepath):
149  def load(self, filepath):
150    """Load the model from a file.
151
152    Args:
153      filepath: the path of the file.
154    """
155    var_dict = torch.load(filepath)
156    self.rnn_model.load_state_dict(var_dict['rnn_state_dict'])
157    self.rnn_init_hidden = nn.Parameter(
158        torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))
159    self.transition_bias = float(var_dict['transition_bias'])
160    self.transition_bias_denominator = float(
161        var_dict['transition_bias_denominator'])
162    self.crp_alpha = float(var_dict['crp_alpha'])
163    self.sigma2 = nn.Parameter(
164        torch.from_numpy(var_dict['sigma2']).to(self.device))
165
166    self.logger.print(
167        3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, '
168        'rnn_init_hidden={}'.format(
169            self.transition_bias, self.crp_alpha, var_dict['sigma2'],
170            var_dict['rnn_init_hidden']))

Load the model from a file.

Args:
  filepath: the path of the file.
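
As a usage sketch (the file name is hypothetical, and `parse_arguments()` is the package-level helper from the README), `save()` and `load()` pair up as follows; `load()` expects a model constructed with the same `model_args`:

```
import uisrnn

model_args, _, _ = uisrnn.parse_arguments()
model = uisrnn.UISRNN(model_args)
model.save('uisrnn_checkpoint.pth')  # hypothetical file name

# Restoring: build a fresh model with the same model_args, then load.
restored = uisrnn.UISRNN(model_args)
restored.load('uisrnn_checkpoint.pth')
```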

def fit_concatenated(self, train_sequence, train_cluster_id, args):
172  def fit_concatenated(self, train_sequence, train_cluster_id, args):
173    """Fit UISRNN model to concatenated sequence and cluster_id.
174
175    Args:
176      train_sequence: the training observation sequence, which is a
177        2-dim numpy array of real numbers, of size `N * D`.
178
179        - `N`: summation of lengths of all utterances.
180        - `D`: observation dimension.
181
182        For example,
183      ```
184      train_sequence =
185      [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
186       [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
187       [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
188       [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
189       [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'
190      ```
191        Here `N=5`, `D=4`.
192
193        We concatenate all training utterances into this single sequence.
194      train_cluster_id: the speaker id sequence, which is a 1-dim list or
195        numpy array of strings, of size `N`.
196        For example,
197      ```
198      train_cluster_id =
199        ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
200      ```
201        'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.
202
203        Note that the order of entries within an utterance is preserved,
204        and all utterances are simply concatenated together.
205      args: Training configurations. See `arguments.py` for details.
206
207    Raises:
208      TypeError: If train_sequence or train_cluster_id is of wrong type.
209      ValueError: If train_sequence or train_cluster_id has wrong dimension.
210    """
211    # check type
212    if (not isinstance(train_sequence, np.ndarray) or
213        train_sequence.dtype != float):
214      raise TypeError('train_sequence should be a numpy array of float type.')
215    if isinstance(train_cluster_id, list):
216      train_cluster_id = np.array(train_cluster_id)
217    if (not isinstance(train_cluster_id, np.ndarray) or
218        not train_cluster_id.dtype.name.startswith(('str', 'unicode'))):
219      raise TypeError('train_cluster_id should be a numpy array of strings.')
220    # check dimension
221    if train_sequence.ndim != 2:
222      raise ValueError('train_sequence must be 2-dim array.')
223    if train_cluster_id.ndim != 1:
224      raise ValueError('train_cluster_id must be 1-dim array.')
225    # check length and size
226    train_total_length, observation_dim = train_sequence.shape
227    if observation_dim != self.observation_dim:
228      raise ValueError('train_sequence does not match the dimension specified '
229                       'by args.observation_dim.')
230    if train_total_length != len(train_cluster_id):
231      raise ValueError('train_sequence length is not equal to '
232                       'train_cluster_id length.')
233
234    self.rnn_model.train()
235    optimizer = self._get_optimizer(optimizer=args.optimizer,
236                                    learning_rate=args.learning_rate)
237
238    sub_sequences, seq_lengths = utils.resize_sequence(
239        sequence=train_sequence,
240        cluster_id=train_cluster_id,
241        num_permutations=args.num_permutations)
242
243    # For batch learning, pack the entire dataset.
244    if args.batch_size is None:
245      packed_train_sequence, rnn_truth = utils.pack_sequence(
246          sub_sequences,
247          seq_lengths,
248          args.batch_size,
249          self.observation_dim,
250          self.device)
251    train_loss = []
252    for num_iter in range(args.train_iteration):
253      optimizer.zero_grad()
254      # For online learning, pack a subset in each iteration.
255      if args.batch_size is not None:
256        packed_train_sequence, rnn_truth = utils.pack_sequence(
257            sub_sequences,
258            seq_lengths,
259            args.batch_size,
260            self.observation_dim,
261            self.device)
262      hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1)
263      mean, _ = self.rnn_model(packed_train_sequence, hidden)
264      # use mean to predict
265      mean = torch.cumsum(mean, dim=0)
266      mean_size = mean.size()
267      mean = torch.mm(
268          torch.diag(
269              1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)),
270          mean.view(mean_size[0], -1))
271      mean = mean.view(mean_size)
272
273      # Likelihood part.
274      loss1 = loss_func.weighted_mse_loss(
275          input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :],
276          target_tensor=rnn_truth,
277          weight=1 / (2 * self.sigma2))
278
279      # Sigma2 prior part.
280      weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth)
281                ** 2).view(-1, observation_dim)
282      num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze()
283      loss2 = loss_func.sigma2_prior_loss(
284          num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2)
285
286      # Regularization part.
287      loss3 = loss_func.regularization_loss(
288          self.rnn_model.parameters(), args.regularization_weight)
289
290      loss = loss1 + loss2 + loss3
291      loss.backward()
292      nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm)
293      optimizer.step()
294      # avoid numerical issues
295      self.sigma2.data.clamp_(min=1e-6)
296
297      if (np.remainder(num_iter, 10) == 0 or
298          num_iter == args.train_iteration - 1):
299        self.logger.print(
300            2,
301            'Iter: {:d}  \t'
302            'Training Loss: {:.4f}    \n'
303            '    Negative Log Likelihood: {:.4f}\t'
304            'Sigma2 Prior: {:.4f}\t'
305            'Regularization: {:.4f}'.format(
306                num_iter,
307                float(loss.data),
308                float(loss1.data),
309                float(loss2.data),
310                float(loss3.data)))
311      train_loss.append(float(loss1.data))  # only save the likelihood part
312    self.logger.print(
313        1, 'Done training with {} iterations'.format(args.train_iteration))

Fit UISRNN model to concatenated sequence and cluster_id.

Args:
  train_sequence: the training observation sequence, which is a 2-dim
    numpy array of real numbers, of size `N * D`.

    - `N`: summation of lengths of all utterances.
    - `D`: observation dimension.

    For example,

    train_sequence =
    [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
     [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
     [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
     [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
     [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'

    Here `N=5`, `D=4`. We concatenate all training utterances into this
    single sequence.
  train_cluster_id: the speaker id sequence, which is a 1-dim list or
    numpy array of strings, of size `N`. For example,

    train_cluster_id =
      ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']

    'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.

    Note that the order of entries within an utterance is preserved,
    and all utterances are simply concatenated together.
  args: Training configurations. See `arguments.py` for details.

Raises:
  TypeError: If train_sequence or train_cluster_id is of wrong type.
  ValueError: If train_sequence or train_cluster_id has wrong dimension.
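
As a sketch (not part of the library source), the docstring's toy example can be passed straight to `fit_concatenated()`; `parse_arguments()` comes from the package README, and the reduced `train_iteration` is an assumption just to keep the run short:

```
import numpy as np
import uisrnn

model_args, training_args, _ = uisrnn.parse_arguments()
model_args.observation_dim = 4      # match D=4 in the toy data below
training_args.train_iteration = 20  # small value, only for this sketch
model = uisrnn.UISRNN(model_args)

# The concatenated sequence and labels from the docstring example (N=5, D=4).
train_sequence = np.array(
    [[1.2, 3.0, -4.1, 6.0],
     [0.8, -1.1, 0.4, 0.5],
     [-0.2, 1.0, 3.8, 5.7],
     [3.8, -0.1, 1.5, 2.3],
     [1.2, 1.4, 3.6, -2.7]])
train_cluster_id = np.array(
    ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0'])

model.fit_concatenated(train_sequence, train_cluster_id, training_args)
```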

def fit(self, train_sequences, train_cluster_ids, args):
315  def fit(self, train_sequences, train_cluster_ids, args):
316    """Fit UISRNN model.
317
318    Args:
319      train_sequences: Either a list of training sequences, or a single
320        concatenated training sequence:
321
322        1. train_sequences is a list, and each element is a 2-dim numpy array
323           of real numbers, of size: `length * D`.
324           The length varies among different sequences, but the D is the same.
325           In speaker diarization, each sequence is the sequence of speaker
326           embeddings of one utterance.
327        2. train_sequences is a single concatenated sequence, which is a
328           2-dim numpy array of real numbers. See `fit_concatenated()`
329           for more details.
330      train_cluster_ids: Ground truth labels for train_sequences:
331
332        1. if train_sequences is a list, this must also be a list of the same
333           size, each element being a 1-dim list or numpy array of strings.
334        2. if train_sequences is a single concatenated sequence, this
335           must also be the concatenated 1-dim list or numpy array of strings.
336      args: Training configurations. See `arguments.py` for details.
337
338    Raises:
339      TypeError: If train_sequences or train_cluster_ids is of wrong type.
340    """
341    if isinstance(train_sequences, np.ndarray):
342      # train_sequences is already the concatenated sequence
343      if self.estimate_transition_bias:
344        # see issue #55: https://github.com/google/uis-rnn/issues/55
345        self.logger.print(
346            2,
347            'Warning: transition_bias cannot be correctly estimated from a '
348            'concatenated sequence; train_sequences will be treated as a '
349            'single sequence. This can lead to inaccurate estimation of '
350            'transition_bias. Please consider estimating transition_bias '
351            'before concatenating the sequences and passing it as an argument.')
352      train_sequences = [train_sequences]
353      train_cluster_ids = [train_cluster_ids]
354    elif isinstance(train_sequences, list):
355      # train_sequences is a list of un-concatenated sequences
356      # we will concatenate it later, after estimating transition_bias
357      pass
358    else:
359      raise TypeError('train_sequences must be a list or numpy.ndarray')
360
361    # estimate transition_bias
362    if self.estimate_transition_bias:
363      (transition_bias,
364       transition_bias_denominator) = utils.estimate_transition_bias(
365           train_cluster_ids)
366      # set or update transition_bias
367      if self.transition_bias is None:
368        self.transition_bias = transition_bias
369        self.transition_bias_denominator = transition_bias_denominator
370      else:
371        self.transition_bias = (
372            self.transition_bias * self.transition_bias_denominator +
373            transition_bias * transition_bias_denominator) / (
374                self.transition_bias_denominator + transition_bias_denominator)
375        self.transition_bias_denominator += transition_bias_denominator
376
377    # concatenate train_sequences
378    (concatenated_train_sequence,
379     concatenated_train_cluster_id) = utils.concatenate_training_data(
380         train_sequences,
381         train_cluster_ids,
382         args.enforce_cluster_id_uniqueness,
383         True)
384
385    self.fit_concatenated(
386        concatenated_train_sequence, concatenated_train_cluster_id, args)

Fit UISRNN model.

Args:
  train_sequences: Either a list of training sequences, or a single
    concatenated training sequence:

    1. train_sequences is a list, and each element is a 2-dim numpy array
       of real numbers, of size `length * D`. The length varies among
       different sequences, but `D` is the same. In speaker diarization,
       each sequence is the sequence of speaker embeddings of one utterance.
    2. train_sequences is a single concatenated sequence, which is a
       2-dim numpy array of real numbers. See `fit_concatenated()`
       for more details.
  train_cluster_ids: Ground truth labels for train_sequences:

    1. if train_sequences is a list, this must also be a list of the same
       size, each element being a 1-dim list or numpy array of strings.
    2. if train_sequences is a single concatenated sequence, this must
       also be the concatenated 1-dim list or numpy array of strings.
  args: Training configurations. See `arguments.py` for details.

Raises:
  TypeError: If train_sequences or train_cluster_ids is of wrong type.
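
A sketch of the list form, which lets `transition_bias` be estimated per utterance (see the warning in the source above); the random arrays below merely stand in for real speaker embeddings, and `parse_arguments()` is the README helper:

```
import numpy as np
import uisrnn

model_args, training_args, _ = uisrnn.parse_arguments()
model_args.observation_dim = 4
model = uisrnn.UISRNN(model_args)

# Two toy "utterances": one array of embeddings and one label array each.
train_sequences = [np.random.randn(100, 4), np.random.randn(80, 4)]
train_cluster_ids = [
    np.array(['A'] * 50 + ['B'] * 50),  # two speakers in utterance 1
    np.array(['A'] * 80),               # one speaker in utterance 2
]

model.fit(train_sequences, train_cluster_ids, training_args)
```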

def predict_single(self, test_sequence, args):
479  def predict_single(self, test_sequence, args):
480    """Predict labels for a single test sequence using UISRNN model.
481
482    Args:
483      test_sequence: the test observation sequence, which is a 2-dim
484        numpy array of real numbers, of size `N * D`.
485
486        - `N`: length of one test utterance.
487        - `D`: observation dimension.
488
489        For example:
490      ```
491      test_sequence =
492      [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
493       [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
494       [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
495       [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
496       [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'
497      ```
498        Here `N=5`, `D=4`.
499      args: Inference configurations. See `arguments.py` for details.
500
501    Returns:
502      predicted_cluster_id: predicted speaker id sequence, which is
503        an array of integers, of size `N`.
504        For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`
505
506    Raises:
507      TypeError: If test_sequence is of wrong type.
508      ValueError: If test_sequence has wrong dimension.
509    """
510    # check type
511    if (not isinstance(test_sequence, np.ndarray) or
512        test_sequence.dtype != float):
513      raise TypeError('test_sequence should be a numpy array of float type.')
514    # check dimension
515    if test_sequence.ndim != 2:
516      raise ValueError('test_sequence must be 2-dim array.')
517    # check size
518    test_sequence_length, observation_dim = test_sequence.shape
519    if observation_dim != self.observation_dim:
520      raise ValueError('test_sequence does not match the dimension specified '
521                       'by args.observation_dim.')
522
523    self.rnn_model.eval()
524    test_sequence = np.tile(test_sequence, (args.test_iteration, 1))
525    test_sequence = autograd.Variable(
526        torch.from_numpy(test_sequence).float()).to(self.device)
527    # bookkeeping for beam search
528    beam_set = [BeamState()]
529    for num_iter in np.arange(0, args.test_iteration * test_sequence_length,
530                              args.look_ahead):
531      max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set])
532      look_ahead_seq = test_sequence[num_iter:num_iter + args.look_ahead, :]
533      look_ahead_seq_length = look_ahead_seq.shape[0]
534      score_set = float('inf') * np.ones(
535          np.append(
536              args.beam_size, max_clusters + 1 + np.arange(
537                  look_ahead_seq_length)))
538      for beam_rank, beam_state in enumerate(beam_set):
539        beam_score_set = self._calculate_score(beam_state, look_ahead_seq)
540        score_set[beam_rank, :] = np.pad(
541            beam_score_set,
542            np.tile([[0, max_clusters - len(beam_state.mean_set)]],
543                    (look_ahead_seq_length, 1)), 'constant',
544            constant_values=float('inf'))
545      # find top scores
546      score_ranked = np.sort(score_set, axis=None)
547      score_ranked[score_ranked == float('inf')] = 0
548      score_ranked = np.trim_zeros(score_ranked)
549      idx_ranked = np.argsort(score_set, axis=None)
550      updated_beam_set = []
551      for new_beam_rank in range(
552          np.min((len(score_ranked), args.beam_size))):
553        total_idx = np.unravel_index(idx_ranked[new_beam_rank],
554                                     score_set.shape)
555        prev_beam_rank = total_idx[0].item()
556        cluster_seq = total_idx[1:]
557        updated_beam_state = self._update_beam_state(
558            beam_set[prev_beam_rank], look_ahead_seq, cluster_seq)
559        updated_beam_set.append(updated_beam_state)
560      beam_set = updated_beam_set
561    predicted_cluster_id = beam_set[0].trace[-test_sequence_length:]
562    return predicted_cluster_id

Predict labels for a single test sequence using UISRNN model.

Args:
  test_sequence: the test observation sequence, which is a 2-dim numpy
    array of real numbers, of size `N * D`.

    - `N`: length of one test utterance.
    - `D`: observation dimension.

    For example:

    test_sequence =
    [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
     [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
     [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
     [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
     [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'

    Here `N=5`, `D=4`.
  args: Inference configurations. See `arguments.py` for details.

Returns:
  predicted_cluster_id: predicted speaker id sequence, which is an array
    of integers, of size `N`. For example,
    `predicted_cluster_id = [0, 1, 0, 0, 1]`.

Raises:
  TypeError: If test_sequence is of wrong type.
  ValueError: If test_sequence has wrong dimension.
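
A minimal inference sketch, assuming `model` has already been trained (or loaded) as in the earlier examples and `inference_args` comes from `uisrnn.parse_arguments()`:

```
import numpy as np

test_sequence = np.random.randn(5, 4)  # toy 'iccc'-style utterance, N=5, D=4
predicted_cluster_id = model.predict_single(test_sequence, inference_args)
print(predicted_cluster_id)            # e.g. [0, 1, 0, 0, 1]
```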

def predict(self, test_sequences, args):
564  def predict(self, test_sequences, args):
565    """Predict labels for a single or many test sequences using UISRNN model.
566
567    Args:
568      test_sequences: Either a list of test sequences, or a single test
569        sequence. Each test sequence is a 2-dim numpy array
570        of real numbers. See `predict_single()` for details.
571      args: Inference configurations. See `arguments.py` for details.
572
573    Returns:
574      predicted_cluster_ids: Predicted labels for test_sequences.
575
576        1. if test_sequences is a list, predicted_cluster_ids will be a list
577           of the same size, where each element is a 1-dim list of integers.
578        2. if test_sequences is a single sequence, predicted_cluster_ids will
579           be a 1-dim list of integers.
580
581    Raises:
582      TypeError: If test_sequences is of wrong type.
583    """
584    # check type
585    if isinstance(test_sequences, np.ndarray):
586      return self.predict_single(test_sequences, args)
587    if isinstance(test_sequences, list):
588      return [self.predict_single(test_sequence, args)
589              for test_sequence in test_sequences]
590    raise TypeError('test_sequences should be either a list or numpy array.')

Predict labels for a single or many test sequences using UISRNN model.

Args:
  test_sequences: Either a list of test sequences, or a single test
    sequence. Each test sequence is a 2-dim numpy array of real numbers.
    See `predict_single()` for details.
  args: Inference configurations. See `arguments.py` for details.

Returns:
  predicted_cluster_ids: Predicted labels for test_sequences.

  1. if test_sequences is a list, predicted_cluster_ids will be a list
     of the same size, where each element is a 1-dim list of integers.
  2. if test_sequences is a single sequence, predicted_cluster_ids will
     be a 1-dim list of integers.

Raises:
  TypeError: If test_sequences is of wrong type.
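
Since `predict()` only dispatches on type, both call forms below are equivalent to calling `predict_single()` directly; as before, `model` and `inference_args` are assumed from the earlier sketches:

```
import numpy as np

# A single array is routed straight to predict_single().
single_ids = model.predict(np.random.randn(5, 4), inference_args)

# A list is mapped element-wise, returning one label list per sequence.
per_utterance_ids = model.predict(
    [np.random.randn(5, 4), np.random.randn(7, 4)], inference_args)
```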

def parallel_predict(model, test_sequences, args, num_processes=4):
593def parallel_predict(model, test_sequences, args, num_processes=4):
594  """Run prediction in parallel using torch.multiprocessing.
595
596  This is a beta feature. It makes prediction slower on CPU, but it has been
597  reported to make prediction faster on GPU.
598
599  Args:
600    model: instance of UISRNN model
601    test_sequences: a list of test sequences. Each test sequence is
602      a 2-dim numpy array of real numbers. See `predict_single()`
603      for details.
604    args: Inference configurations. See `arguments.py` for details.
605    num_processes: number of parallel processes.
606
607  Returns:
608    a list of the same size as test_sequences, where each element
609    is a 1-dim list of integers.
610
611  Raises:
612    TypeError: If test_sequences is of wrong type.
613  """
614  if not isinstance(test_sequences, list):
615    raise TypeError('test_sequences must be a list.')
616  ctx = multiprocessing.get_context('forkserver')
617  model.rnn_model.share_memory()
618  pool = ctx.Pool(num_processes)
619  results = pool.map(
620      functools.partial(model.predict_single, args=args),
621      test_sequences)
622  pool.close()
623  return results

Run prediction in parallel using torch.multiprocessing.

This is a beta feature. It makes prediction slower on CPU, but it has been
reported to make prediction faster on GPU.

Args:
  model: instance of UISRNN model
  test_sequences: a list of test sequences. Each test sequence is a 2-dim
    numpy array of real numbers. See `predict_single()` for details.
  args: Inference configurations. See `arguments.py` for details.
  num_processes: number of parallel processes.

Returns:
  a list of the same size as test_sequences, where each element is a
  1-dim list of integers.

Raises:
  TypeError: If test_sequences is of wrong type.
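
A sketch of parallel inference under the same assumptions as the earlier examples (`parse_arguments()` from the README, random toy inputs); the `__main__` guard matters because forkserver-based multiprocessing re-imports the launching module:

```
import numpy as np
import uisrnn
from uisrnn.uisrnn import parallel_predict

def main():
  model_args, _, inference_args = uisrnn.parse_arguments()
  model_args.observation_dim = 4
  model = uisrnn.UISRNN(model_args)
  # ... train the model, or model.load() a saved checkpoint, here ...

  test_sequences = [np.random.randn(50, 4) for _ in range(8)]  # toy inputs
  results = parallel_predict(
      model, test_sequences, inference_args, num_processes=4)
  print(len(results))  # one 1-dim label list per test sequence

if __name__ == '__main__':
  main()
```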