uisrnn.uisrnn

The UIS-RNN model.

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The UIS-RNN model."""

import colortimelog
import functools
import numpy as np
import torch
from torch import autograd
from torch import multiprocessing
from torch import nn
from torch import optim
import torch.nn.functional as F

from uisrnn import loss_func
from uisrnn import utils

_INITIAL_SIGMA2_VALUE = 0.1


class CoreRNN(nn.Module):
  """The core Recurrent Neural Network used by UIS-RNN."""

  def __init__(self, input_dim, hidden_size, depth, observation_dim, dropout=0):
    super().__init__()
    self.hidden_size = hidden_size
    if depth >= 2:
      self.gru = nn.GRU(input_dim, hidden_size, depth, dropout=dropout)
    else:
      self.gru = nn.GRU(input_dim, hidden_size, depth)
    self.linear_mean1 = nn.Linear(hidden_size, hidden_size)
    self.linear_mean2 = nn.Linear(hidden_size, observation_dim)

  def forward(self, input_seq, hidden=None):
    """The forward function of the module."""
    output_seq, hidden = self.gru(input_seq, hidden)
    if isinstance(output_seq, torch.nn.utils.rnn.PackedSequence):
      output_seq, _ = torch.nn.utils.rnn.pad_packed_sequence(
          output_seq, batch_first=False)
    mean = self.linear_mean2(F.relu(self.linear_mean1(output_seq)))
    return mean, hidden


class BeamState:
  """Structure that contains necessary states for beam search."""

  def __init__(self, source=None):
    if not source:
      self.mean_set = []
      self.hidden_set = []
      self.neg_likelihood = 0
      self.trace = []
      self.block_counts = []
    else:
      self.mean_set = source.mean_set.copy()
      self.hidden_set = source.hidden_set.copy()
      self.trace = source.trace.copy()
      self.block_counts = source.block_counts.copy()
      self.neg_likelihood = source.neg_likelihood

  def append(self, mean, hidden, cluster):
    """Append new item to the BeamState."""
    self.mean_set.append(mean.clone())
    self.hidden_set.append(hidden.clone())
    self.block_counts.append(1)
    self.trace.append(cluster)


class UISRNN:
  """Unbounded Interleaved-State Recurrent Neural Networks."""

  def __init__(self, args):
    """Construct the UISRNN object.

    Args:
      args: Model configurations. See `arguments.py` for details.
    """
    self.observation_dim = args.observation_dim
    self.device = torch.device(
        'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu')
    self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size,
                             args.rnn_depth, self.observation_dim,
                             args.rnn_dropout).to(self.device)
    self.rnn_init_hidden = nn.Parameter(
        torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device))
    # booleans indicating which variables are trainable
    self.estimate_sigma2 = (args.sigma2 is None)
    self.estimate_transition_bias = (args.transition_bias is None)
    # initial values of variables
    sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2
    self.sigma2 = nn.Parameter(
        sigma2 * torch.ones(self.observation_dim).to(self.device))
    self.transition_bias = args.transition_bias
    self.transition_bias_denominator = 0.0
    self.crp_alpha = args.crp_alpha
    self.logger = colortimelog.Logger(args.verbosity)

  def _get_optimizer(self, optimizer, learning_rate):
    """Get optimizer for UISRNN.

    Args:
      optimizer: string - name of the optimizer.
      learning_rate: learning rate for the entire model.
        We do not customize learning rate for separate parts.

    Returns:
      a pytorch "optim" object
    """
    params = [
        {
            'params': self.rnn_model.parameters()
        },  # rnn parameters
        {
            'params': self.rnn_init_hidden
        }  # rnn initial hidden state
    ]
    if self.estimate_sigma2:  # train sigma2
      params.append({
          'params': self.sigma2
      })  # variance parameters
    assert optimizer == 'adam', 'Only adam optimizer is supported.'
    return optim.Adam(params, lr=learning_rate)

  def save(self, filepath):
    """Save the model to a file.

    Args:
      filepath: the path of the file.
    """
    torch.save({
        'rnn_state_dict': self.rnn_model.state_dict(),
        'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(),
        'transition_bias': self.transition_bias,
        'transition_bias_denominator': self.transition_bias_denominator,
        'crp_alpha': self.crp_alpha,
        'sigma2': self.sigma2.detach().cpu().numpy()}, filepath)

  def load(self, filepath):
    """Load the model from a file.

    Args:
      filepath: the path of the file.
    """
    var_dict = torch.load(filepath)
    self.rnn_model.load_state_dict(var_dict['rnn_state_dict'])
    self.rnn_init_hidden = nn.Parameter(
        torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))
    self.transition_bias = float(var_dict['transition_bias'])
    self.transition_bias_denominator = float(
        var_dict['transition_bias_denominator'])
    self.crp_alpha = float(var_dict['crp_alpha'])
    self.sigma2 = nn.Parameter(
        torch.from_numpy(var_dict['sigma2']).to(self.device))

    self.logger.print(
        3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, '
        'rnn_init_hidden={}'.format(
            self.transition_bias, self.crp_alpha, var_dict['sigma2'],
            var_dict['rnn_init_hidden']))

  def fit_concatenated(self, train_sequence, train_cluster_id, args):
    """Fit UISRNN model to concatenated sequence and cluster_id.

    Args:
      train_sequence: the training observation sequence, which is a
        2-dim numpy array of real numbers, of size `N * D`.

        - `N`: summation of lengths of all utterances.
        - `D`: observation dimension.

        For example,
      ```
      train_sequence =
      [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
       [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
       [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
       [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
       [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'
      ```
        Here `N=5`, `D=4`.

        We concatenate all training utterances into this single sequence.
      train_cluster_id: the speaker id sequence, which is a 1-dim list or
        numpy array of strings, of size `N`.
        For example,
      ```
      train_cluster_id =
        ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
      ```
        'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.

        Note that the order of entries within an utterance is preserved,
        and all utterances are simply concatenated together.
      args: Training configurations. See `arguments.py` for details.

    Raises:
      TypeError: If train_sequence or train_cluster_id is of wrong type.
      ValueError: If train_sequence or train_cluster_id has wrong dimension.
    """
    # check type
    if (not isinstance(train_sequence, np.ndarray) or
        train_sequence.dtype != float):
      raise TypeError('train_sequence should be a numpy array of float type.')
    if isinstance(train_cluster_id, list):
      train_cluster_id = np.array(train_cluster_id)
    if (not isinstance(train_cluster_id, np.ndarray) or
        not train_cluster_id.dtype.name.startswith(('str', 'unicode'))):
      raise TypeError('train_cluster_id should be a numpy array of strings.')
    # check dimension
    if train_sequence.ndim != 2:
      raise ValueError('train_sequence must be 2-dim array.')
    if train_cluster_id.ndim != 1:
      raise ValueError('train_cluster_id must be 1-dim array.')
    # check length and size
    train_total_length, observation_dim = train_sequence.shape
    if observation_dim != self.observation_dim:
      raise ValueError('train_sequence does not match the dimension specified '
                       'by args.observation_dim.')
    if train_total_length != len(train_cluster_id):
      raise ValueError('train_sequence length is not equal to '
                       'train_cluster_id length.')

    self.rnn_model.train()
    optimizer = self._get_optimizer(optimizer=args.optimizer,
                                    learning_rate=args.learning_rate)

    sub_sequences, seq_lengths = utils.resize_sequence(
        sequence=train_sequence,
        cluster_id=train_cluster_id,
        num_permutations=args.num_permutations)

    # For batch learning, pack the entire dataset.
    if args.batch_size is None:
      packed_train_sequence, rnn_truth = utils.pack_sequence(
          sub_sequences,
          seq_lengths,
          args.batch_size,
          self.observation_dim,
          self.device)
    train_loss = []
    for num_iter in range(args.train_iteration):
      optimizer.zero_grad()
      # For online learning, pack a subset in each iteration.
      if args.batch_size is not None:
        packed_train_sequence, rnn_truth = utils.pack_sequence(
            sub_sequences,
            seq_lengths,
            args.batch_size,
            self.observation_dim,
            self.device)
      hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1)
      mean, _ = self.rnn_model(packed_train_sequence, hidden)
      # use mean to predict
      mean = torch.cumsum(mean, dim=0)
      mean_size = mean.size()
      mean = torch.mm(
          torch.diag(
              1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)),
          mean.view(mean_size[0], -1))
      mean = mean.view(mean_size)

      # Likelihood part.
      loss1 = loss_func.weighted_mse_loss(
          input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :],
          target_tensor=rnn_truth,
          weight=1 / (2 * self.sigma2))

      # Sigma2 prior part.
      weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth)
                ** 2).view(-1, observation_dim)
      num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze()
      loss2 = loss_func.sigma2_prior_loss(
          num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2)

      # Regularization part.
      loss3 = loss_func.regularization_loss(
          self.rnn_model.parameters(), args.regularization_weight)

      loss = loss1 + loss2 + loss3
      loss.backward()
      nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm)
      optimizer.step()
      # avoid numerical issues
      self.sigma2.data.clamp_(min=1e-6)

      if (np.remainder(num_iter, 10) == 0 or
          num_iter == args.train_iteration - 1):
        self.logger.print(
            2,
            'Iter: {:d}  \t'
            'Training Loss: {:.4f}    \n'
            '    Negative Log Likelihood: {:.4f}\t'
            'Sigma2 Prior: {:.4f}\t'
            'Regularization: {:.4f}'.format(
                num_iter,
                float(loss.data),
                float(loss1.data),
                float(loss2.data),
                float(loss3.data)))
      train_loss.append(float(loss1.data))  # only save the likelihood part
    self.logger.print(
        1, 'Done training with {} iterations'.format(args.train_iteration))

  def fit(self, train_sequences, train_cluster_ids, args):
    """Fit UISRNN model.

    Args:
      train_sequences: Either a list of training sequences, or a single
        concatenated training sequence:

        1. train_sequences is a list, and each element is a 2-dim numpy array
           of real numbers, of size: `length * D`.
           The length varies among different sequences, but the D is the same.
           In speaker diarization, each sequence is the sequence of speaker
           embeddings of one utterance.
        2. train_sequences is a single concatenated sequence, which is a
           2-dim numpy array of real numbers. See `fit_concatenated()`
           for more details.
      train_cluster_ids: Ground truth labels for train_sequences:

        1. if train_sequences is a list, this must also be a list of the same
           size, each element being a 1-dim list or numpy array of strings.
        2. if train_sequences is a single concatenated sequence, this
           must also be the concatenated 1-dim list or numpy array of strings.
      args: Training configurations. See `arguments.py` for details.

    Raises:
      TypeError: If train_sequences or train_cluster_ids is of wrong type.
    """
    if isinstance(train_sequences, np.ndarray):
      # train_sequences is already the concatenated sequence
      if self.estimate_transition_bias:
        # see issue #55: https://github.com/google/uis-rnn/issues/55
        self.logger.print(
            2,
            'Warning: transition_bias cannot be correctly estimated from a '
            'concatenated sequence; train_sequences will be treated as a '
            'single sequence. This can lead to inaccurate estimation of '
            'transition_bias. Please consider estimating transition_bias '
            'before concatenating the sequences and passing it as argument.')
      train_sequences = [train_sequences]
      train_cluster_ids = [train_cluster_ids]
    elif isinstance(train_sequences, list):
      # train_sequences is a list of un-concatenated sequences;
      # we will concatenate it later, after estimating transition_bias
      pass
    else:
      raise TypeError('train_sequences must be a list or numpy.ndarray')

    # estimate transition_bias
    if self.estimate_transition_bias:
      (transition_bias,
       transition_bias_denominator) = utils.estimate_transition_bias(
           train_cluster_ids)
      # set or update transition_bias
      if self.transition_bias is None:
        self.transition_bias = transition_bias
        self.transition_bias_denominator = transition_bias_denominator
      else:
        self.transition_bias = (
            self.transition_bias * self.transition_bias_denominator +
            transition_bias * transition_bias_denominator) / (
                self.transition_bias_denominator + transition_bias_denominator)
        self.transition_bias_denominator += transition_bias_denominator

    # concatenate train_sequences
    (concatenated_train_sequence,
     concatenated_train_cluster_id) = utils.concatenate_training_data(
         train_sequences,
         train_cluster_ids,
         args.enforce_cluster_id_uniqueness,
         True)

    self.fit_concatenated(
        concatenated_train_sequence, concatenated_train_cluster_id, args)

  def _update_beam_state(self, beam_state, look_ahead_seq, cluster_seq):
    """Update a beam state given a look ahead sequence and known cluster
    assignments.

    Args:
      beam_state: A BeamState object.
      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
        look_ahead: number of steps to look ahead in the beam search.
        D: observation dimension
      cluster_seq: Cluster assignment sequence for look_ahead_seq.

    Returns:
      new_beam_state: An updated BeamState object.
    """

    loss = 0
    new_beam_state = BeamState(beam_state)
    for sub_idx, cluster in enumerate(cluster_seq):
      if cluster > len(new_beam_state.mean_set):  # invalid trace
        new_beam_state.neg_likelihood = float('inf')
        break
      elif cluster < len(new_beam_state.mean_set):  # existing cluster
        last_cluster = new_beam_state.trace[-1]
        loss = loss_func.weighted_mse_loss(
            input_tensor=torch.squeeze(new_beam_state.mean_set[cluster]),
            target_tensor=look_ahead_seq[sub_idx, :],
            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
        if cluster == last_cluster:
          loss -= np.log(1 - self.transition_bias)
        else:
          loss -= np.log(self.transition_bias) + np.log(
              new_beam_state.block_counts[cluster]) - np.log(
                  sum(new_beam_state.block_counts) + self.crp_alpha)
        # update new mean and new hidden
        mean, hidden = self.rnn_model(
            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
            new_beam_state.hidden_set[cluster])
        new_beam_state.mean_set[cluster] = (new_beam_state.mean_set[cluster]*(
            (np.array(new_beam_state.trace) == cluster).sum() -
            1).astype(float) + mean.clone()) / (
                np.array(new_beam_state.trace) == cluster).sum().astype(
                    float)  # use mean to predict
        new_beam_state.hidden_set[cluster] = hidden.clone()
        if cluster != last_cluster:
          new_beam_state.block_counts[cluster] += 1
        new_beam_state.trace.append(cluster)
      else:  # new cluster
        init_input = autograd.Variable(
            torch.zeros(self.observation_dim)
        ).unsqueeze(0).unsqueeze(0).to(self.device)
        mean, hidden = self.rnn_model(init_input,
                                      self.rnn_init_hidden)
        loss = loss_func.weighted_mse_loss(
            input_tensor=torch.squeeze(mean),
            target_tensor=look_ahead_seq[sub_idx, :],
            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
        loss -= np.log(self.transition_bias) + np.log(
            self.crp_alpha) - np.log(
                sum(new_beam_state.block_counts) + self.crp_alpha)
        # update new mean and new hidden
        mean, hidden = self.rnn_model(
            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
            hidden)
        new_beam_state.append(mean, hidden, cluster)
      new_beam_state.neg_likelihood += loss
    return new_beam_state

  def _calculate_score(self, beam_state, look_ahead_seq):
    """Calculate negative log likelihoods for all possible state allocations
       of a look ahead sequence, according to the current beam state.

    Args:
      beam_state: A BeamState object.
      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
        look_ahead: number of steps to look ahead in the beam search.
        D: observation dimension

    Returns:
      beam_score_set: a set of scores for each possible state allocation.
    """

    look_ahead, _ = look_ahead_seq.shape
    beam_num_clusters = len(beam_state.mean_set)
    beam_score_set = float('inf') * np.ones(
        beam_num_clusters + 1 + np.arange(look_ahead))
    for cluster_seq, _ in np.ndenumerate(beam_score_set):
      updated_beam_state = self._update_beam_state(beam_state,
                                                   look_ahead_seq, cluster_seq)
      beam_score_set[cluster_seq] = updated_beam_state.neg_likelihood
    return beam_score_set

  def predict_single(self, test_sequence, args):
    """Predict labels for a single test sequence using UISRNN model.

    Args:
      test_sequence: the test observation sequence, which is a 2-dim numpy
        array of real numbers, of size `N * D`.

        - `N`: length of one test utterance.
        - `D`: observation dimension.

        For example:
      ```
      test_sequence =
      [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
       [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
       [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
       [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
       [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'
      ```
        Here `N=5`, `D=4`.
      args: Inference configurations. See `arguments.py` for details.

    Returns:
      predicted_cluster_id: predicted speaker id sequence, which is
        an array of integers, of size `N`.
        For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`

    Raises:
      TypeError: If test_sequence is of wrong type.
      ValueError: If test_sequence has wrong dimension.
    """
    # check type
    if (not isinstance(test_sequence, np.ndarray) or
        test_sequence.dtype != float):
      raise TypeError('test_sequence should be a numpy array of float type.')
    # check dimension
    if test_sequence.ndim != 2:
      raise ValueError('test_sequence must be 2-dim array.')
    # check size
    test_sequence_length, observation_dim = test_sequence.shape
    if observation_dim != self.observation_dim:
      raise ValueError('test_sequence does not match the dimension specified '
                       'by args.observation_dim.')

    self.rnn_model.eval()
    test_sequence = np.tile(test_sequence, (args.test_iteration, 1))
    test_sequence = autograd.Variable(
        torch.from_numpy(test_sequence).float()).to(self.device)
    # bookkeeping for beam search
    beam_set = [BeamState()]
    for num_iter in np.arange(0, args.test_iteration * test_sequence_length,
                              args.look_ahead):
      max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set])
      look_ahead_seq = test_sequence[num_iter: num_iter + args.look_ahead, :]
      look_ahead_seq_length = look_ahead_seq.shape[0]
      score_set = float('inf') * np.ones(
          np.append(
              args.beam_size, max_clusters + 1 + np.arange(
                  look_ahead_seq_length)))
      for beam_rank, beam_state in enumerate(beam_set):
        beam_score_set = self._calculate_score(beam_state, look_ahead_seq)
        score_set[beam_rank, :] = np.pad(
            beam_score_set,
            np.tile([[0, max_clusters - len(beam_state.mean_set)]],
                    (look_ahead_seq_length, 1)), 'constant',
            constant_values=float('inf'))
      # find top scores
      score_ranked = np.sort(score_set, axis=None)
      score_ranked[score_ranked == float('inf')] = 0
      score_ranked = np.trim_zeros(score_ranked)
      idx_ranked = np.argsort(score_set, axis=None)
      updated_beam_set = []
      for new_beam_rank in range(
          np.min((len(score_ranked), args.beam_size))):
        total_idx = np.unravel_index(idx_ranked[new_beam_rank],
                                     score_set.shape)
        prev_beam_rank = total_idx[0].item()
        cluster_seq = total_idx[1:]
        updated_beam_state = self._update_beam_state(
            beam_set[prev_beam_rank], look_ahead_seq, cluster_seq)
        updated_beam_set.append(updated_beam_state)
      beam_set = updated_beam_set
    predicted_cluster_id = beam_set[0].trace[-test_sequence_length:]
    return predicted_cluster_id

  def predict(self, test_sequences, args):
    """Predict labels for a single or many test sequences using UISRNN model.

    Args:
      test_sequences: Either a list of test sequences, or a single test
        sequence. Each test sequence is a 2-dim numpy array
        of real numbers. See `predict_single()` for details.
      args: Inference configurations. See `arguments.py` for details.

    Returns:
      predicted_cluster_ids: Predicted labels for test_sequences.

        1. if test_sequences is a list, predicted_cluster_ids will be a list
           of the same size, where each element is a 1-dim list of integers.
        2. if test_sequences is a single sequence, predicted_cluster_ids will
           be a 1-dim list of integers.

    Raises:
      TypeError: If test_sequences is of wrong type.
    """
    # check type
    if isinstance(test_sequences, np.ndarray):
      return self.predict_single(test_sequences, args)
    if isinstance(test_sequences, list):
      return [self.predict_single(test_sequence, args)
              for test_sequence in test_sequences]
    raise TypeError('test_sequences should be either a list or numpy array.')


def parallel_predict(model, test_sequences, args, num_processes=4):
  """Run prediction in parallel using torch.multiprocessing.

  This is a beta feature. It makes prediction slower on CPU. But it's reported
  that it makes prediction faster on GPU.

  Args:
    model: instance of UISRNN model
    test_sequences: a list of test sequences. Each test sequence is a
      2-dim numpy array of real numbers. See `predict_single()` for details.
    args: Inference configurations. See `arguments.py` for details.
    num_processes: number of parallel processes.

  Returns:
    a list of the same size as test_sequences, where each element
    is a 1-dim list of integers.

  Raises:
    TypeError: If test_sequences is of wrong type.
  """
  if not isinstance(test_sequences, list):
    raise TypeError('test_sequences must be a list.')
  ctx = multiprocessing.get_context('forkserver')
  model.rnn_model.share_memory()
  pool = ctx.Pool(num_processes)
  results = pool.map(
      functools.partial(model.predict_single, args=args),
      test_sequences)
  pool.close()
  return results
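
For orientation, here is a minimal end-to-end training sketch. It assumes the package-level helper `uisrnn.parse_arguments()` (which exposes the configurations defined in `arguments.py`), and uses random arrays as stand-ins for real speaker embeddings; shapes and label conventions follow the `fit()` and `fit_concatenated()` docstrings above.

```python
import numpy as np
import uisrnn

# Default model/training/inference configurations from arguments.py.
model_args, training_args, inference_args = uisrnn.parse_arguments()
model_args.observation_dim = 4      # D: toy embedding dimension
training_args.train_iteration = 20  # keep the toy run short

model = uisrnn.UISRNN(model_args)

# Two toy utterances: each sequence is a (length, D) float array, with one
# '<utterance>_<speaker>' string label per row.
train_sequences = [np.random.rand(100, 4), np.random.rand(80, 4)]
train_cluster_ids = [
    np.array(['iaaa_0'] * 50 + ['iaaa_1'] * 50),
    np.array(['ibbb_0'] * 80),
]

# fit() estimates transition_bias from the per-utterance labels, then
# concatenates the data and delegates to fit_concatenated().
model.fit(train_sequences, train_cluster_ids, training_args)
model.save('saved_uisrnn_model.uisrnn')
```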
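
Inference follows the same pattern; this sketch assumes the checkpoint written above. `predict()` dispatches to `predict_single()`, whose beam search yields one integer speaker label per observation.

```python
import numpy as np
import uisrnn

model_args, _, inference_args = uisrnn.parse_arguments()
model_args.observation_dim = 4

# Restore rnn_state_dict, sigma2, transition_bias, crp_alpha, etc.
model = uisrnn.UISRNN(model_args)
model.load('saved_uisrnn_model.uisrnn')

test_sequences = [np.random.rand(60, 4), np.random.rand(45, 4)]
predicted_cluster_ids = model.predict(test_sequences, inference_args)
# e.g. predicted_cluster_ids[0] might look like [0, 0, 1, 1, 0, ...]
```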
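
For many utterances, `parallel_predict()` can stand in for the list form of `predict()`; the snippet below reuses `model`, `test_sequences`, and `inference_args` from the sketch above. The `if __name__ == '__main__':` guard matters here because the 'forkserver' start method requires the main module to be safely importable by worker processes.

```python
from uisrnn.uisrnn import parallel_predict

if __name__ == '__main__':
    # Shares the RNN weights across workers and maps predict_single()
    # over the list of test sequences.
    predicted_cluster_ids = parallel_predict(
        model, test_sequences, inference_args, num_processes=4)
```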
class CoreRNN(torch.nn.modules.module.Module):
33class CoreRNN(nn.Module):
34  """The core Recurent Neural Network used by UIS-RNN."""
35
36  def __init__(self, input_dim, hidden_size, depth, observation_dim, dropout=0):
37    super().__init__()
38    self.hidden_size = hidden_size
39    if depth >= 2:
40      self.gru = nn.GRU(input_dim, hidden_size, depth, dropout=dropout)
41    else:
42      self.gru = nn.GRU(input_dim, hidden_size, depth)
43    self.linear_mean1 = nn.Linear(hidden_size, hidden_size)
44    self.linear_mean2 = nn.Linear(hidden_size, observation_dim)
45
46  def forward(self, input_seq, hidden=None):
47    """The forward function of the module."""
48    output_seq, hidden = self.gru(input_seq, hidden)
49    if isinstance(output_seq, torch.nn.utils.rnn.PackedSequence):
50      output_seq, _ = torch.nn.utils.rnn.pad_packed_sequence(
51          output_seq, batch_first=False)
52    mean = self.linear_mean2(F.relu(self.linear_mean1(output_seq)))
53    return mean, hidden

The core Recurent Neural Network used by UIS-RNN.

CoreRNN(input_dim, hidden_size, depth, observation_dim, dropout=0)
36  def __init__(self, input_dim, hidden_size, depth, observation_dim, dropout=0):
37    super().__init__()
38    self.hidden_size = hidden_size
39    if depth >= 2:
40      self.gru = nn.GRU(input_dim, hidden_size, depth, dropout=dropout)
41    else:
42      self.gru = nn.GRU(input_dim, hidden_size, depth)
43    self.linear_mean1 = nn.Linear(hidden_size, hidden_size)
44    self.linear_mean2 = nn.Linear(hidden_size, observation_dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

hidden_size
linear_mean1
linear_mean2
def forward(self, input_seq, hidden=None):
46  def forward(self, input_seq, hidden=None):
47    """The forward function of the module."""
48    output_seq, hidden = self.gru(input_seq, hidden)
49    if isinstance(output_seq, torch.nn.utils.rnn.PackedSequence):
50      output_seq, _ = torch.nn.utils.rnn.pad_packed_sequence(
51          output_seq, batch_first=False)
52    mean = self.linear_mean2(F.relu(self.linear_mean1(output_seq)))
53    return mean, hidden

The forward function of the module.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class BeamState:
56class BeamState:
57  """Structure that contains necessary states for beam search."""
58
59  def __init__(self, source=None):
60    if not source:
61      self.mean_set = []
62      self.hidden_set = []
63      self.neg_likelihood = 0
64      self.trace = []
65      self.block_counts = []
66    else:
67      self.mean_set = source.mean_set.copy()
68      self.hidden_set = source.hidden_set.copy()
69      self.trace = source.trace.copy()
70      self.block_counts = source.block_counts.copy()
71      self.neg_likelihood = source.neg_likelihood
72
73  def append(self, mean, hidden, cluster):
74    """Append new item to the BeamState."""
75    self.mean_set.append(mean.clone())
76    self.hidden_set.append(hidden.clone())
77    self.block_counts.append(1)
78    self.trace.append(cluster)

Structure that contains necessary states for beam search.

BeamState(source=None)
59  def __init__(self, source=None):
60    if not source:
61      self.mean_set = []
62      self.hidden_set = []
63      self.neg_likelihood = 0
64      self.trace = []
65      self.block_counts = []
66    else:
67      self.mean_set = source.mean_set.copy()
68      self.hidden_set = source.hidden_set.copy()
69      self.trace = source.trace.copy()
70      self.block_counts = source.block_counts.copy()
71      self.neg_likelihood = source.neg_likelihood
def append(self, mean, hidden, cluster):
73  def append(self, mean, hidden, cluster):
74    """Append new item to the BeamState."""
75    self.mean_set.append(mean.clone())
76    self.hidden_set.append(hidden.clone())
77    self.block_counts.append(1)
78    self.trace.append(cluster)

Append new item to the BeamState.

class UISRNN:
 81class UISRNN:
 82  """Unbounded Interleaved-State Recurrent Neural Networks."""
 83
 84  def __init__(self, args):
 85    """Construct the UISRNN object.
 86
 87    Args:
 88      args: Model configurations. See `arguments.py` for details.
 89    """
 90    self.observation_dim = args.observation_dim
 91    self.device = torch.device(
 92        'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu')
 93    self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size,
 94                             args.rnn_depth, self.observation_dim,
 95                             args.rnn_dropout).to(self.device)
 96    self.rnn_init_hidden = nn.Parameter(
 97        torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device))
 98    # booleans indicating which variables are trainable
 99    self.estimate_sigma2 = (args.sigma2 is None)
100    self.estimate_transition_bias = (args.transition_bias is None)
101    # initial values of variables
102    sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2
103    self.sigma2 = nn.Parameter(
104        sigma2 * torch.ones(self.observation_dim).to(self.device))
105    self.transition_bias = args.transition_bias
106    self.transition_bias_denominator = 0.0
107    self.crp_alpha = args.crp_alpha
108    self.logger = colortimelog.Logger(args.verbosity)
109
110  def _get_optimizer(self, optimizer, learning_rate):
111    """Get optimizer for UISRNN.
112
113    Args:
114      optimizer: string - name of the optimizer.
115      learning_rate: - learning rate for the entire model.
116        We do not customize learning rate for separate parts.
117
118    Returns:
119      a pytorch "optim" object
120    """
121    params = [
122        {
123            'params': self.rnn_model.parameters()
124        },  # rnn parameters
125        {
126            'params': self.rnn_init_hidden
127        }  # rnn initial hidden state
128    ]
129    if self.estimate_sigma2:  # train sigma2
130      params.append({
131          'params': self.sigma2
132      })  # variance parameters
133    assert optimizer == 'adam', 'Only adam optimizer is supported.'
134    return optim.Adam(params, lr=learning_rate)
135
136  def save(self, filepath):
137    """Save the model to a file.
138
139    Args:
140      filepath: the path of the file.
141    """
142    torch.save({
143        'rnn_state_dict': self.rnn_model.state_dict(),
144        'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(),
145        'transition_bias': self.transition_bias,
146        'transition_bias_denominator': self.transition_bias_denominator,
147        'crp_alpha': self.crp_alpha,
148        'sigma2': self.sigma2.detach().cpu().numpy()}, filepath)
149
150  def load(self, filepath):
151    """Load the model from a file.
152
153    Args:
154      filepath: the path of the file.
155    """
156    var_dict = torch.load(filepath)
157    self.rnn_model.load_state_dict(var_dict['rnn_state_dict'])
158    self.rnn_init_hidden = nn.Parameter(
159        torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))
160    self.transition_bias = float(var_dict['transition_bias'])
161    self.transition_bias_denominator = float(
162        var_dict['transition_bias_denominator'])
163    self.crp_alpha = float(var_dict['crp_alpha'])
164    self.sigma2 = nn.Parameter(
165        torch.from_numpy(var_dict['sigma2']).to(self.device))
166
167    self.logger.print(
168        3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, '
169        'rnn_init_hidden={}'.format(
170            self.transition_bias, self.crp_alpha, var_dict['sigma2'],
171            var_dict['rnn_init_hidden']))
172
173  def fit_concatenated(self, train_sequence, train_cluster_id, args):
174    """Fit UISRNN model to concatenated sequence and cluster_id.
175
176    Args:
177      train_sequence: the training observation sequence, which is a
178        2-dim numpy array of real numbers, of size `N * D`.
179
180        - `N`: summation of lengths of all utterances.
181        - `D`: observation dimension.
182
183        For example,
184      ```
185      train_sequence =
186      [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
187       [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
188       [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
189       [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
190       [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'
191      ```
192        Here `N=5`, `D=4`.
193
194        We concatenate all training utterances into this single sequence.
195      train_cluster_id: the speaker id sequence, which is 1-dim list or
196        numpy array of strings, of size `N`.
197        For example,
198      ```
199      train_cluster_id =
200        ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
201      ```
202        'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.
203
204        Note that the order of entries within an utterance are preserved,
205        and all utterances are simply concatenated together.
206      args: Training configurations. See `arguments.py` for details.
207
208    Raises:
209      TypeError: If train_sequence or train_cluster_id is of wrong type.
210      ValueError: If train_sequence or train_cluster_id has wrong dimension.
211    """
212    # check type
213    if (not isinstance(train_sequence, np.ndarray) or
214        train_sequence.dtype != float):
215      raise TypeError('train_sequence should be a numpy array of float type.')
216    if isinstance(train_cluster_id, list):
217      train_cluster_id = np.array(train_cluster_id)
218    if (not isinstance(train_cluster_id, np.ndarray) or
219        not train_cluster_id.dtype.name.startswith(('str', 'unicode'))):
220      raise TypeError('train_cluster_id type be a numpy array of strings.')
221    # check dimension
222    if train_sequence.ndim != 2:
223      raise ValueError('train_sequence must be 2-dim array.')
224    if train_cluster_id.ndim != 1:
225      raise ValueError('train_cluster_id must be 1-dim array.')
226    # check length and size
227    train_total_length, observation_dim = train_sequence.shape
228    if observation_dim != self.observation_dim:
229      raise ValueError('train_sequence does not match the dimension specified '
230                       'by args.observation_dim.')
231    if train_total_length != len(train_cluster_id):
232      raise ValueError('train_sequence length is not equal to '
233                       'train_cluster_id length.')
234
235    self.rnn_model.train()
236    optimizer = self._get_optimizer(optimizer=args.optimizer,
237                                    learning_rate=args.learning_rate)
238
239    sub_sequences, seq_lengths = utils.resize_sequence(
240        sequence=train_sequence,
241        cluster_id=train_cluster_id,
242        num_permutations=args.num_permutations)
243
244    # For batch learning, pack the entire dataset.
245    if args.batch_size is None:
246      packed_train_sequence, rnn_truth = utils.pack_sequence(
247          sub_sequences,
248          seq_lengths,
249          args.batch_size,
250          self.observation_dim,
251          self.device)
252    train_loss = []
253    for num_iter in range(args.train_iteration):
254      optimizer.zero_grad()
255      # For online learning, pack a subset in each iteration.
256      if args.batch_size is not None:
257        packed_train_sequence, rnn_truth = utils.pack_sequence(
258            sub_sequences,
259            seq_lengths,
260            args.batch_size,
261            self.observation_dim,
262            self.device)
263      hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1)
264      mean, _ = self.rnn_model(packed_train_sequence, hidden)
265      # use mean to predict
266      mean = torch.cumsum(mean, dim=0)
267      mean_size = mean.size()
268      mean = torch.mm(
269          torch.diag(
270              1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)),
271          mean.view(mean_size[0], -1))
272      mean = mean.view(mean_size)
273
274      # Likelihood part.
275      loss1 = loss_func.weighted_mse_loss(
276          input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :],
277          target_tensor=rnn_truth,
278          weight=1 / (2 * self.sigma2))
279
280      # Sigma2 prior part.
281      weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth)
282                ** 2).view(-1, observation_dim)
283      num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze()
284      loss2 = loss_func.sigma2_prior_loss(
285          num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2)
286
287      # Regularization part.
288      loss3 = loss_func.regularization_loss(
289          self.rnn_model.parameters(), args.regularization_weight)
290
291      loss = loss1 + loss2 + loss3
292      loss.backward()
293      nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm)
294      optimizer.step()
295      # avoid numerical issues
296      self.sigma2.data.clamp_(min=1e-6)
297
298      if (np.remainder(num_iter, 10) == 0 or
299          num_iter == args.train_iteration - 1):
300        self.logger.print(
301            2,
302            'Iter: {:d}  \t'
303            'Training Loss: {:.4f}    \n'
304            '    Negative Log Likelihood: {:.4f}\t'
305            'Sigma2 Prior: {:.4f}\t'
306            'Regularization: {:.4f}'.format(
307                num_iter,
308                float(loss.data),
309                float(loss1.data),
310                float(loss2.data),
311                float(loss3.data)))
312      train_loss.append(float(loss1.data))  # only save the likelihood part
313    self.logger.print(
314        1, 'Done training with {} iterations'.format(args.train_iteration))
315
316  def fit(self, train_sequences, train_cluster_ids, args):
317    """Fit UISRNN model.
318
319    Args:
320      train_sequences: Either a list of training sequences, or a single
321        concatenated training sequence:
322
323        1. train_sequences is list, and each element is a 2-dim numpy array
324           of real numbers, of size: `length * D`.
325           The length varies among different sequences, but the D is the same.
326           In speaker diarization, each sequence is the sequence of speaker
327           embeddings of one utterance.
328        2. train_sequences is a single concatenated sequence, which is a
329           2-dim numpy array of real numbers. See `fit_concatenated()`
330           for more details.
331      train_cluster_ids: Ground truth labels for train_sequences:
332
333        1. if train_sequences is a list, this must also be a list of the same
334           size, each element being a 1-dim list or numpy array of strings.
335        2. if train_sequences is a single concatenated sequence, this
336           must also be the concatenated 1-dim list or numpy array of strings
337      args: Training configurations. See `arguments.py` for details.
338
339    Raises:
340      TypeError: If train_sequences or train_cluster_ids is of wrong type.
341    """
342    if isinstance(train_sequences, np.ndarray):
343      # train_sequences is already the concatenated sequence
344      if self.estimate_transition_bias:
345        # see issue #55: https://github.com/google/uis-rnn/issues/55
346        self.logger.print(
347            2,
348            'Warning: transition_bias cannot be correctly estimated from a '
349            'concatenated sequence; train_sequences will be treated as a '
350            'single sequence. This can lead to inaccurate estimation of '
351            'transition_bias. Please, consider estimating transition_bias '
352            'before concatenating the sequences and passing it as argument.')
353      train_sequences = [train_sequences]
354      train_cluster_ids = [train_cluster_ids]
355    elif isinstance(train_sequences, list):
356      # train_sequences is a list of un-concatenated sequences
357      # we will concatenate it later, after estimating transition_bias
358      pass
359    else:
360      raise TypeError('train_sequences must be a list or numpy.ndarray')
361
362    # estimate transition_bias
363    if self.estimate_transition_bias:
364      (transition_bias,
365       transition_bias_denominator) = utils.estimate_transition_bias(
366           train_cluster_ids)
367      # set or update transition_bias
368      if self.transition_bias is None:
369        self.transition_bias = transition_bias
370        self.transition_bias_denominator = transition_bias_denominator
371      else:
372        self.transition_bias = (
373            self.transition_bias * self.transition_bias_denominator +
374            transition_bias * transition_bias_denominator) / (
375                self.transition_bias_denominator + transition_bias_denominator)
376        self.transition_bias_denominator += transition_bias_denominator
377
378    # concatenate train_sequences
379    (concatenated_train_sequence,
380     concatenated_train_cluster_id) = utils.concatenate_training_data(
381         train_sequences,
382         train_cluster_ids,
383         args.enforce_cluster_id_uniqueness,
384         True)
385
386    self.fit_concatenated(
387        concatenated_train_sequence, concatenated_train_cluster_id, args)
388
389  def _update_beam_state(self, beam_state, look_ahead_seq, cluster_seq):
390    """Update a beam state given a look ahead sequence and known cluster
391    assignments.
392
393    Args:
394      beam_state: A BeamState object.
395      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
396        look_ahead: number of step to look ahead in the beam search.
397        D: observation dimension
398      cluster_seq: Cluster assignment sequence for look_ahead_seq.
399
400    Returns:
401      new_beam_state: An updated BeamState object.
402    """
403
404    loss = 0
405    new_beam_state = BeamState(beam_state)
406    for sub_idx, cluster in enumerate(cluster_seq):
407      if cluster > len(new_beam_state.mean_set):  # invalid trace
408        new_beam_state.neg_likelihood = float('inf')
409        break
410      elif cluster < len(new_beam_state.mean_set):  # existing cluster
411        last_cluster = new_beam_state.trace[-1]
412        loss = loss_func.weighted_mse_loss(
413            input_tensor=torch.squeeze(new_beam_state.mean_set[cluster]),
414            target_tensor=look_ahead_seq[sub_idx, :],
415            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
416        if cluster == last_cluster:
417          loss -= np.log(1 - self.transition_bias)
418        else:
419          loss -= np.log(self.transition_bias) + np.log(
420              new_beam_state.block_counts[cluster]) - np.log(
421                  sum(new_beam_state.block_counts) + self.crp_alpha)
422        # update new mean and new hidden
423        mean, hidden = self.rnn_model(
424            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
425            new_beam_state.hidden_set[cluster])
426        new_beam_state.mean_set[cluster] = (new_beam_state.mean_set[cluster]*(
427            (np.array(new_beam_state.trace) == cluster).sum() -
428            1).astype(float) + mean.clone()) / (
429                np.array(new_beam_state.trace) == cluster).sum().astype(
430                    float)  # use mean to predict
431        new_beam_state.hidden_set[cluster] = hidden.clone()
432        if cluster != last_cluster:
433          new_beam_state.block_counts[cluster] += 1
434        new_beam_state.trace.append(cluster)
435      else:  # new cluster
436        init_input = autograd.Variable(
437            torch.zeros(self.observation_dim)
438        ).unsqueeze(0).unsqueeze(0).to(self.device)
439        mean, hidden = self.rnn_model(init_input,
440                                      self.rnn_init_hidden)
441        loss = loss_func.weighted_mse_loss(
442            input_tensor=torch.squeeze(mean),
443            target_tensor=look_ahead_seq[sub_idx, :],
444            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
445        loss -= np.log(self.transition_bias) + np.log(
446            self.crp_alpha) - np.log(
447                sum(new_beam_state.block_counts) + self.crp_alpha)
448        # update new min and new hidden
449        mean, hidden = self.rnn_model(
450            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
451            hidden)
452        new_beam_state.append(mean, hidden, cluster)
453      new_beam_state.neg_likelihood += loss
454    return new_beam_state
455
456  def _calculate_score(self, beam_state, look_ahead_seq):
457    """Calculate negative log likelihoods for all possible state allocations
458       of a look ahead sequence, according to the current beam state.
459
460    Args:
461      beam_state: A BeamState object.
462      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
463        look_ahead: number of step to look ahead in the beam search.
464        D: observation dimension
465
466    Returns:
467      beam_score_set: a set of scores for each possible state allocation.
468    """
469
470    look_ahead, _ = look_ahead_seq.shape
471    beam_num_clusters = len(beam_state.mean_set)
472    beam_score_set = float('inf') * np.ones(
473        beam_num_clusters + 1 + np.arange(look_ahead))
474    for cluster_seq, _ in np.ndenumerate(beam_score_set):
475      updated_beam_state = self._update_beam_state(beam_state,
476                                                   look_ahead_seq, cluster_seq)
477      beam_score_set[cluster_seq] = updated_beam_state.neg_likelihood
478    return beam_score_set
479
480  def predict_single(self, test_sequence, args):
481    """Predict labels for a single test sequence using UISRNN model.
482
483    Args:
484      test_sequence: the test observation sequence, which is 2-dim numpy array
485        of real numbers, of size `N * D`.
486
487        - `N`: length of one test utterance.
488        - `D` : observation dimension.
489
490        For example:
491      ```
492      test_sequence =
493      [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
494       [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
495       [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
496       [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
497       [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'
498      ```
499        Here `N=5`, `D=4`.
500      args: Inference configurations. See `arguments.py` for details.
501
502    Returns:
503      predicted_cluster_id: predicted speaker id sequence, which is
504        an array of integers, of size `N`.
505        For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`
506
507    Raises:
508      TypeError: If test_sequence is of wrong type.
509      ValueError: If test_sequence has wrong dimension.
510    """
511    # check type
512    if (not isinstance(test_sequence, np.ndarray) or
513        test_sequence.dtype != float):
514      raise TypeError('test_sequence should be a numpy array of float type.')
515    # check dimension
516    if test_sequence.ndim != 2:
517      raise ValueError('test_sequence must be 2-dim array.')
518    # check size
519    test_sequence_length, observation_dim = test_sequence.shape
520    if observation_dim != self.observation_dim:
521      raise ValueError('test_sequence does not match the dimension specified '
522                       'by args.observation_dim.')
523
524    self.rnn_model.eval()
525    test_sequence = np.tile(test_sequence, (args.test_iteration, 1))
526    test_sequence = autograd.Variable(
527        torch.from_numpy(test_sequence).float()).to(self.device)
528    # bookkeeping for beam search
529    beam_set = [BeamState()]
530    for num_iter in np.arange(0, args.test_iteration * test_sequence_length,
531                              args.look_ahead):
532      max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set])
533      look_ahead_seq = test_sequence[num_iter:  num_iter + args.look_ahead, :]
534      look_ahead_seq_length = look_ahead_seq.shape[0]
535      score_set = float('inf') * np.ones(
536          np.append(
537              args.beam_size, max_clusters + 1 + np.arange(
538                  look_ahead_seq_length)))
539      for beam_rank, beam_state in enumerate(beam_set):
540        beam_score_set = self._calculate_score(beam_state, look_ahead_seq)
541        score_set[beam_rank, :] = np.pad(
542            beam_score_set,
543            np.tile([[0, max_clusters - len(beam_state.mean_set)]],
544                    (look_ahead_seq_length, 1)), 'constant',
545            constant_values=float('inf'))
546      # find top scores
547      score_ranked = np.sort(score_set, axis=None)
548      score_ranked[score_ranked == float('inf')] = 0
549      score_ranked = np.trim_zeros(score_ranked)
550      idx_ranked = np.argsort(score_set, axis=None)
551      updated_beam_set = []
552      for new_beam_rank in range(
553          np.min((len(score_ranked), args.beam_size))):
554        total_idx = np.unravel_index(idx_ranked[new_beam_rank],
555                                     score_set.shape)
556        prev_beam_rank = total_idx[0].item()
557        cluster_seq = total_idx[1:]
558        updated_beam_state = self._update_beam_state(
559            beam_set[prev_beam_rank], look_ahead_seq, cluster_seq)
560        updated_beam_set.append(updated_beam_state)
561      beam_set = updated_beam_set
562    predicted_cluster_id = beam_set[0].trace[-test_sequence_length:]
563    return predicted_cluster_id
564
565  def predict(self, test_sequences, args):
566    """Predict labels for a single or many test sequences using UISRNN model.
567
568    Args:
569      test_sequences: Either a list of test sequences, or a single test
570        sequence. Each test sequence is a 2-dim numpy array
571        of real numbers. See `predict_single()` for details.
572      args: Inference configurations. See `arguments.py` for details.
573
574    Returns:
575      predicted_cluster_ids: Predicted labels for test_sequences.
576
577        1. if test_sequences is a list, predicted_cluster_ids will be a list
578           of the same size, where each element being a 1-dim list of strings.
579        2. if test_sequences is a single sequence, predicted_cluster_ids will
580           be a 1-dim list of strings
581
582    Raises:
583      TypeError: If test_sequences is of wrong type.
584    """
585    # check type
586    if isinstance(test_sequences, np.ndarray):
587      return self.predict_single(test_sequences, args)
588    if isinstance(test_sequences, list):
589      return [self.predict_single(test_sequence, args)
590              for test_sequence in test_sequences]
591    raise TypeError('test_sequences should be either a list or numpy array.')

Unbounded Interleaved-State Recurrent Neural Networks.

UISRNN(args)
 84  def __init__(self, args):
 85    """Construct the UISRNN object.
 86
 87    Args:
 88      args: Model configurations. See `arguments.py` for details.
 89    """
 90    self.observation_dim = args.observation_dim
 91    self.device = torch.device(
 92        'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu')
 93    self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size,
 94                             args.rnn_depth, self.observation_dim,
 95                             args.rnn_dropout).to(self.device)
 96    self.rnn_init_hidden = nn.Parameter(
 97        torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device))
 98    # booleans indicating which variables are trainable
 99    self.estimate_sigma2 = (args.sigma2 is None)
100    self.estimate_transition_bias = (args.transition_bias is None)
101    # initial values of variables
102    sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2
103    self.sigma2 = nn.Parameter(
104        sigma2 * torch.ones(self.observation_dim).to(self.device))
105    self.transition_bias = args.transition_bias
106    self.transition_bias_denominator = 0.0
107    self.crp_alpha = args.crp_alpha
108    self.logger = colortimelog.Logger(args.verbosity)

Construct the UISRNN object.

Args: args: Model configurations. See arguments.py for details.

observation_dim
device
rnn_model
rnn_init_hidden
estimate_sigma2
estimate_transition_bias
sigma2
transition_bias
transition_bias_denominator
crp_alpha
logger
def save(self, filepath):
136  def save(self, filepath):
137    """Save the model to a file.
138
139    Args:
140      filepath: the path of the file.
141    """
142    torch.save({
143        'rnn_state_dict': self.rnn_model.state_dict(),
144        'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(),
145        'transition_bias': self.transition_bias,
146        'transition_bias_denominator': self.transition_bias_denominator,
147        'crp_alpha': self.crp_alpha,
148        'sigma2': self.sigma2.detach().cpu().numpy()}, filepath)

Save the model to a file.

Args: filepath: the path of the file.

def load(self, filepath):
150  def load(self, filepath):
151    """Load the model from a file.
152
153    Args:
154      filepath: the path of the file.
155    """
156    var_dict = torch.load(filepath)
157    self.rnn_model.load_state_dict(var_dict['rnn_state_dict'])
158    self.rnn_init_hidden = nn.Parameter(
159        torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))
160    self.transition_bias = float(var_dict['transition_bias'])
161    self.transition_bias_denominator = float(
162        var_dict['transition_bias_denominator'])
163    self.crp_alpha = float(var_dict['crp_alpha'])
164    self.sigma2 = nn.Parameter(
165        torch.from_numpy(var_dict['sigma2']).to(self.device))
166
167    self.logger.print(
168        3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, '
169        'rnn_init_hidden={}'.format(
170            self.transition_bias, self.crp_alpha, var_dict['sigma2'],
171            var_dict['rnn_init_hidden']))

Load the model from a file.

Args: filepath: the path of the file.
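
A sketch of restoring that file into a freshly constructed model. `load()` overwrites the RNN weights together with `transition_bias`, `transition_bias_denominator`, `crp_alpha` and `sigma2`:

```
restored_model = uisrnn.UISRNN(model_args)  # same configurations as before
restored_model.load('saved_uisrnn.uisrnn')
```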

def fit_concatenated(self, train_sequence, train_cluster_id, args):
173  def fit_concatenated(self, train_sequence, train_cluster_id, args):
174    """Fit UISRNN model to concatenated sequence and cluster_id.
175
176    Args:
177      train_sequence: the training observation sequence, which is a
178        2-dim numpy array of real numbers, of size `N * D`.
179
180        - `N`: summation of lengths of all utterances.
181        - `D`: observation dimension.
182
183        For example,
184      ```
185      train_sequence =
186      [[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
187       [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
188       [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
189       [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
190       [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'
191      ```
192        Here `N=5`, `D=4`.
193
194        We concatenate all training utterances into this single sequence.
195      train_cluster_id: the speaker id sequence, which is a 1-dim list or
196        numpy array of strings, of size `N`.
197        For example,
198      ```
199      train_cluster_id =
200        ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
201      ```
202        'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.
203
204        Note that the order of entries within an utterance is preserved,
205        and all utterances are simply concatenated together.
206      args: Training configurations. See `arguments.py` for details.
207
208    Raises:
209      TypeError: If train_sequence or train_cluster_id is of wrong type.
210      ValueError: If train_sequence or train_cluster_id has wrong dimension.
211    """
212    # check type
213    if (not isinstance(train_sequence, np.ndarray) or
214        train_sequence.dtype != float):
215      raise TypeError('train_sequence should be a numpy array of float type.')
216    if isinstance(train_cluster_id, list):
217      train_cluster_id = np.array(train_cluster_id)
218    if (not isinstance(train_cluster_id, np.ndarray) or
219        not train_cluster_id.dtype.name.startswith(('str', 'unicode'))):
220      raise TypeError('train_cluster_id should be a numpy array of strings.')
221    # check dimension
222    if train_sequence.ndim != 2:
223      raise ValueError('train_sequence must be 2-dim array.')
224    if train_cluster_id.ndim != 1:
225      raise ValueError('train_cluster_id must be 1-dim array.')
226    # check length and size
227    train_total_length, observation_dim = train_sequence.shape
228    if observation_dim != self.observation_dim:
229      raise ValueError('train_sequence does not match the dimension specified '
230                       'by args.observation_dim.')
231    if train_total_length != len(train_cluster_id):
232      raise ValueError('train_sequence length is not equal to '
233                       'train_cluster_id length.')
234
235    self.rnn_model.train()
236    optimizer = self._get_optimizer(optimizer=args.optimizer,
237                                    learning_rate=args.learning_rate)
238
239    sub_sequences, seq_lengths = utils.resize_sequence(
240        sequence=train_sequence,
241        cluster_id=train_cluster_id,
242        num_permutations=args.num_permutations)
243
244    # For batch learning, pack the entire dataset.
245    if args.batch_size is None:
246      packed_train_sequence, rnn_truth = utils.pack_sequence(
247          sub_sequences,
248          seq_lengths,
249          args.batch_size,
250          self.observation_dim,
251          self.device)
252    train_loss = []
253    for num_iter in range(args.train_iteration):
254      optimizer.zero_grad()
255      # For online learning, pack a subset in each iteration.
256      if args.batch_size is not None:
257        packed_train_sequence, rnn_truth = utils.pack_sequence(
258            sub_sequences,
259            seq_lengths,
260            args.batch_size,
261            self.observation_dim,
262            self.device)
263      hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1)
264      mean, _ = self.rnn_model(packed_train_sequence, hidden)
265      # use mean to predict: running average of GRU outputs over time
266      mean = torch.cumsum(mean, dim=0)
267      mean_size = mean.size()
268      mean = torch.mm(
269          torch.diag(
270              1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)),
271          mean.view(mean_size[0], -1))
272      mean = mean.view(mean_size)
273
274      # Likelihood part.
275      loss1 = loss_func.weighted_mse_loss(
276          input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :],
277          target_tensor=rnn_truth,
278          weight=1 / (2 * self.sigma2))
279
280      # Sigma2 prior part.
281      weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth)
282                ** 2).view(-1, observation_dim)
283      num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze()
284      loss2 = loss_func.sigma2_prior_loss(
285          num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2)
286
287      # Regularization part.
288      loss3 = loss_func.regularization_loss(
289          self.rnn_model.parameters(), args.regularization_weight)
290
291      loss = loss1 + loss2 + loss3
292      loss.backward()
293      nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm)
294      optimizer.step()
295      # avoid numerical issues
296      self.sigma2.data.clamp_(min=1e-6)
297
298      if (np.remainder(num_iter, 10) == 0 or
299          num_iter == args.train_iteration - 1):
300        self.logger.print(
301            2,
302            'Iter: {:d}  \t'
303            'Training Loss: {:.4f}    \n'
304            '    Negative Log Likelihood: {:.4f}\t'
305            'Sigma2 Prior: {:.4f}\t'
306            'Regularization: {:.4f}'.format(
307                num_iter,
308                float(loss.data),
309                float(loss1.data),
310                float(loss2.data),
311                float(loss3.data)))
312      train_loss.append(float(loss1.data))  # only save the likelihood part
313    self.logger.print(
314        1, 'Done training with {} iterations'.format(args.train_iteration))

Fit UISRNN model to concatenated sequence and cluster_id.

Args: train_sequence: the training observation sequence, which is a 2-dim numpy array of real numbers, of size N * D.

- `N`: summation of lengths of all utterances.
- `D`: observation dimension.

For example,

train_sequence =
[[1.2 3.0 -4.1 6.0]    --> an entry of speaker #0 from utterance 'iaaa'
 [0.8 -1.1 0.4 0.5]    --> an entry of speaker #1 from utterance 'iaaa'
 [-0.2 1.0 3.8 5.7]    --> an entry of speaker #0 from utterance 'iaaa'
 [3.8 -0.1 1.5 2.3]    --> an entry of speaker #0 from utterance 'ibbb'
 [1.2 1.4 3.6 -2.7]]   --> an entry of speaker #0 from utterance 'ibbb'

Here `N=5`, `D=4`.

We concatenate all training utterances into this single sequence.

train_cluster_id: the speaker id sequence, which is a 1-dim list or numpy array of strings, of size N. For example,

train_cluster_id =
  ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']

'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.

Note that the order of entries within an utterance is preserved,
and all utterances are simply concatenated together.

args: Training configurations. See arguments.py for details.

Raises:
  TypeError: If train_sequence or train_cluster_id is of wrong type.
  ValueError: If train_sequence or train_cluster_id has wrong dimension.
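
Putting the docstring example into runnable form, as a sketch that reuses `model` and `training_args` from the construction sketch above (with `model_args.observation_dim = 4` to match this toy data):

```
import numpy as np

train_sequence = np.array(
    [[1.2, 3.0, -4.1, 6.0],   # speaker #0, utterance 'iaaa'
     [0.8, -1.1, 0.4, 0.5],   # speaker #1, utterance 'iaaa'
     [-0.2, 1.0, 3.8, 5.7],   # speaker #0, utterance 'iaaa'
     [3.8, -0.1, 1.5, 2.3],   # speaker #0, utterance 'ibbb'
     [1.2, 1.4, 3.6, -2.7]])  # speaker #0, utterance 'ibbb'
train_cluster_id = np.array(
    ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0'])

model.fit_concatenated(train_sequence, train_cluster_id, training_args)
```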

def fit(self, train_sequences, train_cluster_ids, args):
316  def fit(self, train_sequences, train_cluster_ids, args):
317    """Fit UISRNN model.
318
319    Args:
320      train_sequences: Either a list of training sequences, or a single
321        concatenated training sequence:
322
323        1. train_sequences is a list, and each element is a 2-dim numpy array
324           of real numbers, of size: `length * D`.
325           The length varies among different sequences, but `D` is the same.
326           In speaker diarization, each sequence is the sequence of speaker
327           embeddings of one utterance.
328        2. train_sequences is a single concatenated sequence, which is a
329           2-dim numpy array of real numbers. See `fit_concatenated()`
330           for more details.
331      train_cluster_ids: Ground truth labels for train_sequences:
332
333        1. if train_sequences is a list, this must also be a list of the same
334           size, each element being a 1-dim list or numpy array of strings.
335        2. if train_sequences is a single concatenated sequence, this
336           must also be the concatenated 1-dim list or numpy array of strings.
337      args: Training configurations. See `arguments.py` for details.
338
339    Raises:
340      TypeError: If train_sequences or train_cluster_ids is of wrong type.
341    """
342    if isinstance(train_sequences, np.ndarray):
343      # train_sequences is already the concatenated sequence
344      if self.estimate_transition_bias:
345        # see issue #55: https://github.com/google/uis-rnn/issues/55
346        self.logger.print(
347            2,
348            'Warning: transition_bias cannot be correctly estimated from a '
349            'concatenated sequence; train_sequences will be treated as a '
350            'single sequence. This can lead to inaccurate estimation of '
351            'transition_bias. Please consider estimating transition_bias '
352            'before concatenating the sequences and passing it as argument.')
353      train_sequences = [train_sequences]
354      train_cluster_ids = [train_cluster_ids]
355    elif isinstance(train_sequences, list):
356      # train_sequences is a list of un-concatenated sequences
357      # we will concatenate it later, after estimating transition_bias
358      pass
359    else:
360      raise TypeError('train_sequences must be a list or numpy.ndarray')
361
362    # estimate transition_bias
363    if self.estimate_transition_bias:
364      (transition_bias,
365       transition_bias_denominator) = utils.estimate_transition_bias(
366           train_cluster_ids)
367      # set or update transition_bias
368      if self.transition_bias is None:
369        self.transition_bias = transition_bias
370        self.transition_bias_denominator = transition_bias_denominator
371      else:
372        self.transition_bias = (
373            self.transition_bias * self.transition_bias_denominator +
374            transition_bias * transition_bias_denominator) / (
375                self.transition_bias_denominator + transition_bias_denominator)
376        self.transition_bias_denominator += transition_bias_denominator
377
378    # concatenate train_sequences
379    (concatenated_train_sequence,
380     concatenated_train_cluster_id) = utils.concatenate_training_data(
381         train_sequences,
382         train_cluster_ids,
383         args.enforce_cluster_id_uniqueness,
384         True)
385
386    self.fit_concatenated(
387        concatenated_train_sequence, concatenated_train_cluster_id, args)

Fit UISRNN model.

Args: train_sequences: Either a list of training sequences, or a single concatenated training sequence:

1. train_sequences is a list, and each element is a 2-dim numpy array
   of real numbers, of size: `length * D`.
   The length varies among different sequences, but `D` is the same.
   In speaker diarization, each sequence is the sequence of speaker
   embeddings of one utterance.
2. train_sequences is a single concatenated sequence, which is a
   2-dim numpy array of real numbers. See `fit_concatenated()`
   for more details.

train_cluster_ids: Ground truth labels for train_sequences:

1. if train_sequences is a list, this must also be a list of the same
   size, each element being a 1-dim list or numpy array of strings.
2. if train_sequences is a single concatenated sequence, this
   must also be the concatenated 1-dim list or numpy array of strings.

args: Training configurations. See arguments.py for details.

Raises: TypeError: If train_sequences or train_cluster_ids is of wrong type.
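
A sketch of the list form, with made-up shapes and labels. Passing un-concatenated sequences lets `fit()` estimate `transition_bias` per utterance before concatenating, which avoids the warning described above:

```
import numpy as np

train_sequences = [
    np.random.rand(100, 4),  # utterance 1: 100 observations, D=4
    np.random.rand(80, 4),   # utterance 2: 80 observations
]
train_cluster_ids = [
    np.array(['0'] * 60 + ['1'] * 40),  # two speakers in utterance 1
    np.array(['0'] * 80),               # one speaker in utterance 2
]
model.fit(train_sequences, train_cluster_ids, training_args)
```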

def predict_single(self, test_sequence, args):
480  def predict_single(self, test_sequence, args):
481    """Predict labels for a single test sequence using UISRNN model.
482
483    Args:
484      test_sequence: the test observation sequence, which is a 2-dim numpy array
485        of real numbers, of size `N * D`.
486
487        - `N`: length of one test utterance.
488        - `D`: observation dimension.
489
490        For example:
491      ```
492      test_sequence =
493      [[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
494       [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
495       [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
496       [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
497       [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'
498      ```
499        Here `N=5`, `D=4`.
500      args: Inference configurations. See `arguments.py` for details.
501
502    Returns:
503      predicted_cluster_id: predicted speaker id sequence, which is
504        an array of integers, of size `N`.
505        For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`
506
507    Raises:
508      TypeError: If test_sequence is of wrong type.
509      ValueError: If test_sequence has wrong dimension.
510    """
511    # check type
512    if (not isinstance(test_sequence, np.ndarray) or
513        test_sequence.dtype != float):
514      raise TypeError('test_sequence should be a numpy array of float type.')
515    # check dimension
516    if test_sequence.ndim != 2:
517      raise ValueError('test_sequence must be 2-dim array.')
518    # check size
519    test_sequence_length, observation_dim = test_sequence.shape
520    if observation_dim != self.observation_dim:
521      raise ValueError('test_sequence does not match the dimension specified '
522                       'by args.observation_dim.')
523
524    self.rnn_model.eval()
525    test_sequence = np.tile(test_sequence, (args.test_iteration, 1))
526    test_sequence = autograd.Variable(
527        torch.from_numpy(test_sequence).float()).to(self.device)
528    # bookkeeping for beam search
529    beam_set = [BeamState()]
530    for num_iter in np.arange(0, args.test_iteration * test_sequence_length,
531                              args.look_ahead):
532      max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set])
533      look_ahead_seq = test_sequence[num_iter:num_iter + args.look_ahead, :]
534      look_ahead_seq_length = look_ahead_seq.shape[0]
535      score_set = float('inf') * np.ones(
536          np.append(
537              args.beam_size, max_clusters + 1 + np.arange(
538                  look_ahead_seq_length)))
539      for beam_rank, beam_state in enumerate(beam_set):
540        beam_score_set = self._calculate_score(beam_state, look_ahead_seq)
541        score_set[beam_rank, :] = np.pad(
542            beam_score_set,
543            np.tile([[0, max_clusters - len(beam_state.mean_set)]],
544                    (look_ahead_seq_length, 1)), 'constant',
545            constant_values=float('inf'))
546      # find top scores
547      score_ranked = np.sort(score_set, axis=None)
548      score_ranked[score_ranked == float('inf')] = 0
549      score_ranked = np.trim_zeros(score_ranked)
550      idx_ranked = np.argsort(score_set, axis=None)
551      updated_beam_set = []
552      for new_beam_rank in range(
553          np.min((len(score_ranked), args.beam_size))):
554        total_idx = np.unravel_index(idx_ranked[new_beam_rank],
555                                     score_set.shape)
556        prev_beam_rank = total_idx[0].item()
557        cluster_seq = total_idx[1:]
558        updated_beam_state = self._update_beam_state(
559            beam_set[prev_beam_rank], look_ahead_seq, cluster_seq)
560        updated_beam_set.append(updated_beam_state)
561      beam_set = updated_beam_set
562    predicted_cluster_id = beam_set[0].trace[-test_sequence_length:]
563    return predicted_cluster_id

Predict labels for a single test sequence using UISRNN model.

Args: test_sequence: the test observation sequence, which is a 2-dim numpy array of real numbers, of size N * D.

- `N`: length of one test utterance.
- `D`: observation dimension.

For example:

test_sequence =
[[2.2 -1.0 3.0 5.6]    --> 1st entry of utterance 'iccc'
 [0.5 1.8 -3.2 0.4]    --> 2nd entry of utterance 'iccc'
 [-2.2 5.0 1.8 3.7]    --> 3rd entry of utterance 'iccc'
 [-3.8 0.1 1.4 3.3]    --> 4th entry of utterance 'iccc'
 [0.1 2.7 3.5 -1.7]]   --> 5th entry of utterance 'iccc'

Here `N=5`, `D=4`.

args: Inference configurations. See arguments.py for details.

Returns: predicted_cluster_id: predicted speaker id sequence, which is an array of integers, of size N. For example, predicted_cluster_id = [0, 1, 0, 0, 1]

Raises:
  TypeError: If test_sequence is of wrong type.
  ValueError: If test_sequence has wrong dimension.
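
A minimal inference sketch, reusing the `inference_args` namespace from `parse_arguments()`; the random test data is illustrative only:

```
import numpy as np

test_sequence = np.random.rand(10, 4)  # N=10 observations, D=4, float dtype
predicted_cluster_id = model.predict_single(test_sequence, inference_args)
print(predicted_cluster_id)  # e.g. [0, 0, 1, 1, 0, 0, 2, 2, 2, 0]
```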

def predict(self, test_sequences, args):
565  def predict(self, test_sequences, args):
566    """Predict labels for a single or many test sequences using UISRNN model.
567
568    Args:
569      test_sequences: Either a list of test sequences, or a single test
570        sequence. Each test sequence is a 2-dim numpy array
571        of real numbers. See `predict_single()` for details.
572      args: Inference configurations. See `arguments.py` for details.
573
574    Returns:
575      predicted_cluster_ids: Predicted labels for test_sequences.
576
577        1. if test_sequences is a list, predicted_cluster_ids will be a list
578           of the same size, where each element is a 1-dim list of integers.
579        2. if test_sequences is a single sequence, predicted_cluster_ids will
580           be a 1-dim list of integers.
581
582    Raises:
583      TypeError: If test_sequences is of wrong type.
584    """
585    # check type
586    if isinstance(test_sequences, np.ndarray):
587      return self.predict_single(test_sequences, args)
588    if isinstance(test_sequences, list):
589      return [self.predict_single(test_sequence, args)
590              for test_sequence in test_sequences]
591    raise TypeError('test_sequences should be either a list or numpy array.')

Predict labels for a single or many test sequences using UISRNN model.

Args:
  test_sequences: Either a list of test sequences, or a single test sequence.
    Each test sequence is a 2-dim numpy array of real numbers. See
    predict_single() for details.
  args: Inference configurations. See arguments.py for details.

Returns: predicted_cluster_ids: Predicted labels for test_sequences.

1. if test_sequences is a list, predicted_cluster_ids will be a list
   of the same size, where each element is a 1-dim list of integers.
2. if test_sequences is a single sequence, predicted_cluster_ids will
   be a 1-dim list of integers.

Raises: TypeError: If test_sequences is of wrong type.
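
Since `predict()` simply dispatches on the input type, both call styles below reduce to one or more `predict_single()` calls (a sketch with made-up data):

```
import numpy as np

# A single numpy array behaves exactly like predict_single().
single_result = model.predict(np.random.rand(10, 4), inference_args)

# A list of sequences is predicted element by element.
batch_results = model.predict(
    [np.random.rand(10, 4), np.random.rand(7, 4)], inference_args)
assert len(batch_results) == 2
```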

def parallel_predict(model, test_sequences, args, num_processes=4):
594def parallel_predict(model, test_sequences, args, num_processes=4):
595  """Run prediction in parallel using torch.multiprocessing.
596
597  This is a beta feature. It makes prediction slower on CPU, but it has
598  been reported to make prediction faster on GPU.
599
600  Args:
601    model: instance of UISRNN model
602    test_sequences: a list of test sequences. Each test sequence is a
603      2-dim numpy array of real numbers. See `predict_single()` for
604      details.
605    args: Inference configurations. See `arguments.py` for details.
606    num_processes: number of parallel processes.
607
608  Returns:
609    a list of the same size as test_sequences, where each element
610    is a 1-dim list of integers.
611
612  Raises:
613    TypeError: If test_sequences is of wrong type.
614  """
615  if not isinstance(test_sequences, list):
616    raise TypeError('test_sequences must be a list.')
617  ctx = multiprocessing.get_context('forkserver')
618  model.rnn_model.share_memory()
619  pool = ctx.Pool(num_processes)
620  results = pool.map(
621      functools.partial(model.predict_single, args=args),
622      test_sequences)
623  pool.close()
624  return results

Run prediction in parallel using torch.multiprocessing.

This is a beta feature. It makes prediction slower on CPU, but it has been reported to make prediction faster on GPU.

Args:
  model: instance of UISRNN model.
  test_sequences: a list of test sequences. Each test sequence is a 2-dim
    numpy array of real numbers. See predict_single() for details.
  args: Inference configurations. See arguments.py for details.
  num_processes: number of parallel processes.

Returns: a list of the same size as test_sequences, where each element is a 1-dim list of integers.

Raises: TypeError: If test_sequences is of wrong type.
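
A usage sketch. Because the 'forkserver' start method spawns fresh interpreter processes, the call should sit under the usual main-module guard; this assumes parallel_predict is re-exported at package level (otherwise import it from uisrnn.uisrnn):

```
import numpy as np
import uisrnn

def main():
  model_args, _, inference_args = uisrnn.parse_arguments()
  model_args.observation_dim = 4
  model = uisrnn.UISRNN(model_args)
  model.load('saved_uisrnn.uisrnn')  # a previously saved model (see save())
  test_sequences = [np.random.rand(50, 4) for _ in range(8)]
  results = uisrnn.parallel_predict(
      model, test_sequences, inference_args, num_processes=4)
  print(len(results))  # one 1-dim list of predicted ids per test sequence

if __name__ == '__main__':
  main()
```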