uisrnn
The module for Unbounded Interleaved-State Recurrent Neural Network.
An introduction is available at [README.md](https://github.com/google/uis-rnn/blob/master/README.md).
```python
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The module for Unbounded Interleaved-State Recurrent Neural Network.

An introduction is available at [README.md].

[README.md]: https://github.com/google/uis-rnn/blob/master/README.md
"""

from . import arguments
from . import evals
from . import uisrnn
from . import utils

parse_arguments = arguments.parse_arguments
compute_sequence_match_accuracy = evals.compute_sequence_match_accuracy
output_result = utils.output_result
UISRNN = uisrnn.UISRNN
parallel_predict = uisrnn.parallel_predict
```
```python
def parse_arguments():
  """Parse arguments.

  Returns:
    A tuple of:

    - `model_args`: model arguments
    - `training_args`: training arguments
    - `inference_args`: inference arguments
  """
  # model configurations
  model_parser = argparse.ArgumentParser(
      description='Model configurations.', add_help=False)

  model_parser.add_argument(
      '--observation_dim',
      default=_DEFAULT_OBSERVATION_DIM,
      type=int,
      help='The dimension of the embeddings (e.g. d-vectors).')

  model_parser.add_argument(
      '--rnn_hidden_size',
      default=512,
      type=int,
      help='The number of nodes for each RNN layer.')
  model_parser.add_argument(
      '--rnn_depth',
      default=1,
      type=int,
      help='The number of RNN layers.')
  model_parser.add_argument(
      '--rnn_dropout',
      default=0.2,
      type=float,
      help='The dropout rate for all RNN layers.')
  model_parser.add_argument(
      '--transition_bias',
      default=None,
      type=float,
      help='The value of p0, corresponding to Eq. (6) in the '
           'paper. If the value is given, we will fix to this value. If the '
           'value is None, we will estimate it from training data '
           'using Eq. (13) in the paper.')
  model_parser.add_argument(
      '--crp_alpha',
      default=1.0,
      type=float,
      help='The value of alpha for the Chinese restaurant process (CRP), '
           'corresponding to Eq. (7) in the paper. In this open source '
           'implementation, currently we only support using a given value '
           'of crp_alpha.')
  model_parser.add_argument(
      '--sigma2',
      default=None,
      type=float,
      help='The value of sigma squared, corresponding to Eq. (11) in the '
           'paper. If the value is given, we will fix to this value. If the '
           'value is None, we will estimate it from training data.')
  model_parser.add_argument(
      '--verbosity',
      default=3,
      type=int,
      help='How verbose the logging information will be. Higher value '
           'represents more verbose information. A general guideline: '
           '0 for fatals; 1 for errors; 2 for finishing important steps; '
           '3 for finishing less important steps; 4 or above for debugging '
           'information.')
  model_parser.add_argument(
      '--enable_cuda',
      default=True,
      type=str2bool,
      help='Whether we should use CUDA if it is available. If False, we will '
           'always use CPU.')

  # training configurations
  training_parser = argparse.ArgumentParser(
      description='Training configurations.', add_help=False)

  training_parser.add_argument(
      '--optimizer',
      '-o',
      default='adam',
      choices=['adam'],
      help='The optimizer for training.')
  training_parser.add_argument(
      '--learning_rate',
      '-l',
      default=1e-3,
      type=float,
      help='The learning rate for training.')
  training_parser.add_argument(
      '--train_iteration',
      '-t',
      default=20000,
      type=int,
      help='The total number of training iterations.')
  training_parser.add_argument(
      '--batch_size',
      '-b',
      default=10,
      type=int,
      help='The batch size for training.')
  training_parser.add_argument(
      '--num_permutations',
      default=10,
      type=int,
      help='The number of permutations per utterance sampled in the training '
           'data.')
  training_parser.add_argument(
      '--sigma_alpha',
      default=1.0,
      type=float,
      help='The inverse gamma shape for estimating sigma2. This value is only '
           'meaningful when sigma2 is not given, and estimated from data.')
  training_parser.add_argument(
      '--sigma_beta',
      default=1.0,
      type=float,
      help='The inverse gamma scale for estimating sigma2. This value is only '
           'meaningful when sigma2 is not given, and estimated from data.')
  training_parser.add_argument(
      '--regularization_weight',
      '-r',
      default=1e-5,
      type=float,
      help='The network regularization multiplicative factor.')
  training_parser.add_argument(
      '--grad_max_norm',
      default=5.0,
      type=float,
      help='Max norm of the gradient.')
  training_parser.add_argument(
      '--enforce_cluster_id_uniqueness',
      default=True,
      type=str2bool,
      help='Whether to enforce cluster ID uniqueness across different '
           'training sequences. Only effective when the first input to fit() '
           'is a list of sequences. In general, assume the cluster IDs for two '
           'sequences are [a, b] and [a, c]. If the `a` from the two sequences '
           'are not the same label, then this arg should be True.')

  # inference configurations
  inference_parser = argparse.ArgumentParser(
      description='Inference configurations.', add_help=False)

  inference_parser.add_argument(
      '--beam_size',
      '-s',
      default=10,
      type=int,
      help='The beam search size for inference.')
  inference_parser.add_argument(
      '--look_ahead',
      default=1,
      type=int,
      help='The number of look ahead steps during inference.')
  inference_parser.add_argument(
      '--test_iteration',
      default=2,
      type=int,
      help='During inference, we concatenate M duplicates of the test '
           'sequence, and run inference on this concatenated sequence. '
           'Then we return the inference results on the last duplicate as the '
           'final prediction for the test sequence.')

  # a super parser for sanity checks
  super_parser = argparse.ArgumentParser(
      parents=[model_parser, training_parser, inference_parser])

  # get arguments
  super_parser.parse_args()
  model_args, _ = model_parser.parse_known_args()
  training_args, _ = training_parser.parse_known_args()
  inference_args, _ = inference_parser.parse_known_args()

  return (model_args, training_args, inference_args)
```
Parse arguments.
Returns: A tuple of:
- `model_args`: model arguments
- `training_args`: training arguments
- `inference_args`: inference arguments
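A minimal usage sketch: the three namespaces come straight from the tuple above, and flags are read from `sys.argv`, so this is typically called at a training script's entry point.

```python
import uisrnn

# Parse all command-line flags into three argparse namespaces.
model_args, training_args, inference_args = uisrnn.parse_arguments()

# Each namespace carries the defaults documented above, e.g.:
print(model_args.rnn_hidden_size)   # 512 unless --rnn_hidden_size is set
print(training_args.learning_rate)  # 0.001 unless --learning_rate is set
```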
```python
def compute_sequence_match_accuracy(sequence1, sequence2):
  """Compute the accuracy between two sequences by finding optimal matching.

  Args:
    sequence1: A list of integers or strings.
    sequence2: A list of integers or strings.

  Returns:
    accuracy: sequence matching accuracy as a number in [0.0, 1.0]

  Raises:
    TypeError: If sequence1 or sequence2 is not a list.
    ValueError: If sequence1 and sequence2 are not the same size.
  """
  if not isinstance(sequence1, list) or not isinstance(sequence2, list):
    raise TypeError('sequence1 and sequence2 must be lists')
  if not sequence1 or len(sequence1) != len(sequence2):
    raise ValueError(
        'sequence1 and sequence2 must have the same non-zero length')
  # get unique ids from sequences
  unique_ids1 = sorted(set(sequence1))
  unique_ids2 = sorted(set(sequence2))
  inverse_index1 = get_list_inverse_index(unique_ids1)
  inverse_index2 = get_list_inverse_index(unique_ids2)
  # get the count matrix
  count_matrix = np.zeros((len(unique_ids1), len(unique_ids2)))
  for item1, item2 in zip(sequence1, sequence2):
    index1 = inverse_index1[item1]
    index2 = inverse_index2[item2]
    count_matrix[index1, index2] += 1.0
  row_index, col_index = optimize.linear_sum_assignment(-count_matrix)
  optimal_match_count = count_matrix[row_index, col_index].sum()
  accuracy = optimal_match_count / len(sequence1)
  return accuracy
```
Compute the accuracy between two sequences by finding optimal matching.
Args:
  sequence1: A list of integers or strings.
  sequence2: A list of integers or strings.

Returns: accuracy: sequence matching accuracy as a number in [0.0, 1.0]

Raises:
  TypeError: If sequence1 or sequence2 is not a list.
  ValueError: If sequence1 and sequence2 are not the same size.
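A small worked example with hypothetical label sequences. The two sequences may use different label vocabularies, since only the optimal one-to-one matching between labels matters:

```python
import uisrnn

# Hypothetical reference and predicted label sequences.
reference = [0, 0, 1, 1]
prediction = ['a', 'a', 'b', 'x']

# The optimal matching pairs 0<->'a' and 1<->'b', covering 3 of 4 entries.
accuracy = uisrnn.compute_sequence_match_accuracy(reference, prediction)
print(accuracy)  # 0.75
```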
```python
def output_result(model_args, training_args, test_record):
  """Produce a string to summarize the experiment."""
  accuracy_array, _ = zip(*test_record)
  total_accuracy = np.mean(accuracy_array)
  output_string = """
Config:
  sigma_alpha: {}
  sigma_beta: {}
  crp_alpha: {}
  learning rate: {}
  regularization: {}
  batch size: {}

Performance:
  averaged accuracy: {:.6f}
  accuracy numbers for all testing sequences:
  """.strip().format(
      training_args.sigma_alpha,
      training_args.sigma_beta,
      model_args.crp_alpha,
      training_args.learning_rate,
      training_args.regularization_weight,
      training_args.batch_size,
      total_accuracy)
  for accuracy in accuracy_array:
    output_string += '\n    {:.6f}'.format(accuracy)
  output_string += '\n' + '=' * 80 + '\n'
  filename = 'layer_{}_{}_{:.1f}_result.txt'.format(
      model_args.rnn_hidden_size,
      model_args.rnn_depth, model_args.rnn_dropout)
  with open(filename, 'a') as file_object:
    file_object.write(output_string)
  return output_string
```
Produce a string to summarize the experiment.
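A usage sketch, assuming `model_args` and `training_args` come from `parse_arguments()`, and `test_record` is a list of `(accuracy, other_info)` tuples collected over test sequences, which is what the `zip(*test_record)` above expects. The `test_record` values here are hypothetical:

```python
import uisrnn

model_args, training_args, _ = uisrnn.parse_arguments()

# Hypothetical per-sequence results: (accuracy, anything_else) tuples.
test_record = [(0.95, None), (0.88, None), (0.91, None)]

# Appends a summary to 'layer_<size>_<depth>_<dropout>_result.txt'
# and returns the same summary string.
summary = uisrnn.output_result(model_args, training_args, test_record)
print(summary)
```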
````python
class UISRNN:
  """Unbounded Interleaved-State Recurrent Neural Networks."""

  def __init__(self, args):
    """Construct the UISRNN object.

    Args:
      args: Model configurations. See `arguments.py` for details.
    """
    self.observation_dim = args.observation_dim
    self.device = torch.device(
        'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu')
    self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size,
                             args.rnn_depth, self.observation_dim,
                             args.rnn_dropout).to(self.device)
    self.rnn_init_hidden = nn.Parameter(
        torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device))
    # booleans indicating which variables are trainable
    self.estimate_sigma2 = (args.sigma2 is None)
    self.estimate_transition_bias = (args.transition_bias is None)
    # initial values of variables
    sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2
    self.sigma2 = nn.Parameter(
        sigma2 * torch.ones(self.observation_dim).to(self.device))
    self.transition_bias = args.transition_bias
    self.transition_bias_denominator = 0.0
    self.crp_alpha = args.crp_alpha
    self.logger = colortimelog.Logger(args.verbosity)

  def _get_optimizer(self, optimizer, learning_rate):
    """Get optimizer for UISRNN.

    Args:
      optimizer: string - name of the optimizer.
      learning_rate: learning rate for the entire model.
        We do not customize learning rate for separate parts.

    Returns:
      a pytorch "optim" object
    """
    params = [
        {
            'params': self.rnn_model.parameters()
        },  # rnn parameters
        {
            'params': self.rnn_init_hidden
        }  # rnn initial hidden state
    ]
    if self.estimate_sigma2:  # train sigma2
      params.append({
          'params': self.sigma2
      })  # variance parameters
    assert optimizer == 'adam', 'Only adam optimizer is supported.'
    return optim.Adam(params, lr=learning_rate)

  def save(self, filepath):
    """Save the model to a file.

    Args:
      filepath: the path of the file.
    """
    torch.save({
        'rnn_state_dict': self.rnn_model.state_dict(),
        'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(),
        'transition_bias': self.transition_bias,
        'transition_bias_denominator': self.transition_bias_denominator,
        'crp_alpha': self.crp_alpha,
        'sigma2': self.sigma2.detach().cpu().numpy()}, filepath)

  def load(self, filepath):
    """Load the model from a file.

    Args:
      filepath: the path of the file.
    """
    var_dict = torch.load(filepath)
    self.rnn_model.load_state_dict(var_dict['rnn_state_dict'])
    self.rnn_init_hidden = nn.Parameter(
        torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))
    self.transition_bias = float(var_dict['transition_bias'])
    self.transition_bias_denominator = float(
        var_dict['transition_bias_denominator'])
    self.crp_alpha = float(var_dict['crp_alpha'])
    self.sigma2 = nn.Parameter(
        torch.from_numpy(var_dict['sigma2']).to(self.device))

    self.logger.print(
        3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, '
        'rnn_init_hidden={}'.format(
            self.transition_bias, self.crp_alpha, var_dict['sigma2'],
            var_dict['rnn_init_hidden']))

  def fit_concatenated(self, train_sequence, train_cluster_id, args):
    """Fit UISRNN model to concatenated sequence and cluster_id.

    Args:
      train_sequence: the training observation sequence, which is a
        2-dim numpy array of real numbers, of size `N * D`.

        - `N`: summation of lengths of all utterances.
        - `D`: observation dimension.

        For example,
        ```
        train_sequence =
        [[1.2 3.0 -4.1 6.0]   --> an entry of speaker #0 from utterance 'iaaa'
         [0.8 -1.1 0.4 0.5]   --> an entry of speaker #1 from utterance 'iaaa'
         [-0.2 1.0 3.8 5.7]   --> an entry of speaker #0 from utterance 'iaaa'
         [3.8 -0.1 1.5 2.3]   --> an entry of speaker #0 from utterance 'ibbb'
         [1.2 1.4 3.6 -2.7]]  --> an entry of speaker #0 from utterance 'ibbb'
        ```
        Here `N=5`, `D=4`.

        We concatenate all training utterances into this single sequence.
      train_cluster_id: the speaker id sequence, which is a 1-dim list or
        numpy array of strings, of size `N`.
        For example,
        ```
        train_cluster_id =
          ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
        ```
        'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.

        Note that the order of entries within an utterance is preserved,
        and all utterances are simply concatenated together.
      args: Training configurations. See `arguments.py` for details.

    Raises:
      TypeError: If train_sequence or train_cluster_id is of wrong type.
      ValueError: If train_sequence or train_cluster_id has wrong dimension.
    """
    # check type
    if (not isinstance(train_sequence, np.ndarray) or
        train_sequence.dtype != float):
      raise TypeError('train_sequence should be a numpy array of float type.')
    if isinstance(train_cluster_id, list):
      train_cluster_id = np.array(train_cluster_id)
    if (not isinstance(train_cluster_id, np.ndarray) or
        not train_cluster_id.dtype.name.startswith(('str', 'unicode'))):
      raise TypeError('train_cluster_id must be a numpy array of strings.')
    # check dimension
    if train_sequence.ndim != 2:
      raise ValueError('train_sequence must be 2-dim array.')
    if train_cluster_id.ndim != 1:
      raise ValueError('train_cluster_id must be 1-dim array.')
    # check length and size
    train_total_length, observation_dim = train_sequence.shape
    if observation_dim != self.observation_dim:
      raise ValueError('train_sequence does not match the dimension specified '
                       'by args.observation_dim.')
    if train_total_length != len(train_cluster_id):
      raise ValueError('train_sequence length is not equal to '
                       'train_cluster_id length.')

    self.rnn_model.train()
    optimizer = self._get_optimizer(optimizer=args.optimizer,
                                    learning_rate=args.learning_rate)

    sub_sequences, seq_lengths = utils.resize_sequence(
        sequence=train_sequence,
        cluster_id=train_cluster_id,
        num_permutations=args.num_permutations)

    # For batch learning, pack the entire dataset.
    if args.batch_size is None:
      packed_train_sequence, rnn_truth = utils.pack_sequence(
          sub_sequences,
          seq_lengths,
          args.batch_size,
          self.observation_dim,
          self.device)
    train_loss = []
    for num_iter in range(args.train_iteration):
      optimizer.zero_grad()
      # For online learning, pack a subset in each iteration.
      if args.batch_size is not None:
        packed_train_sequence, rnn_truth = utils.pack_sequence(
            sub_sequences,
            seq_lengths,
            args.batch_size,
            self.observation_dim,
            self.device)
      hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1)
      mean, _ = self.rnn_model(packed_train_sequence, hidden)
      # use mean to predict
      mean = torch.cumsum(mean, dim=0)
      mean_size = mean.size()
      mean = torch.mm(
          torch.diag(
              1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)),
          mean.view(mean_size[0], -1))
      mean = mean.view(mean_size)

      # Likelihood part.
      loss1 = loss_func.weighted_mse_loss(
          input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :],
          target_tensor=rnn_truth,
          weight=1 / (2 * self.sigma2))

      # Sigma2 prior part.
      weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth)
                ** 2).view(-1, observation_dim)
      num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze()
      loss2 = loss_func.sigma2_prior_loss(
          num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2)

      # Regularization part.
      loss3 = loss_func.regularization_loss(
          self.rnn_model.parameters(), args.regularization_weight)

      loss = loss1 + loss2 + loss3
      loss.backward()
      nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm)
      optimizer.step()
      # avoid numerical issues
      self.sigma2.data.clamp_(min=1e-6)

      if (np.remainder(num_iter, 10) == 0 or
          num_iter == args.train_iteration - 1):
        self.logger.print(
            2,
            'Iter: {:d} \t'
            'Training Loss: {:.4f} \n'
            '    Negative Log Likelihood: {:.4f}\t'
            'Sigma2 Prior: {:.4f}\t'
            'Regularization: {:.4f}'.format(
                num_iter,
                float(loss.data),
                float(loss1.data),
                float(loss2.data),
                float(loss3.data)))
      train_loss.append(float(loss1.data))  # only save the likelihood part
    self.logger.print(
        1, 'Done training with {} iterations'.format(args.train_iteration))

  def fit(self, train_sequences, train_cluster_ids, args):
    """Fit UISRNN model.

    Args:
      train_sequences: Either a list of training sequences, or a single
        concatenated training sequence:

        1. train_sequences is a list, and each element is a 2-dim numpy array
           of real numbers, of size: `length * D`.
           The length varies among different sequences, but the D is the same.
           In speaker diarization, each sequence is the sequence of speaker
           embeddings of one utterance.
        2. train_sequences is a single concatenated sequence, which is a
           2-dim numpy array of real numbers. See `fit_concatenated()`
           for more details.
      train_cluster_ids: Ground truth labels for train_sequences:

        1. if train_sequences is a list, this must also be a list of the same
           size, each element being a 1-dim list or numpy array of strings.
        2. if train_sequences is a single concatenated sequence, this must
           also be the concatenated 1-dim list or numpy array of strings.
      args: Training configurations. See `arguments.py` for details.

    Raises:
      TypeError: If train_sequences or train_cluster_ids is of wrong type.
    """
    if isinstance(train_sequences, np.ndarray):
      # train_sequences is already the concatenated sequence
      if self.estimate_transition_bias:
        # see issue #55: https://github.com/google/uis-rnn/issues/55
        self.logger.print(
            2,
            'Warning: transition_bias cannot be correctly estimated from a '
            'concatenated sequence; train_sequences will be treated as a '
            'single sequence. This can lead to inaccurate estimation of '
            'transition_bias. Please, consider estimating transition_bias '
            'before concatenating the sequences and passing it as argument.')
      train_sequences = [train_sequences]
      train_cluster_ids = [train_cluster_ids]
    elif isinstance(train_sequences, list):
      # train_sequences is a list of un-concatenated sequences
      # we will concatenate it later, after estimating transition_bias
      pass
    else:
      raise TypeError('train_sequences must be a list or numpy.ndarray')

    # estimate transition_bias
    if self.estimate_transition_bias:
      (transition_bias,
       transition_bias_denominator) = utils.estimate_transition_bias(
           train_cluster_ids)
      # set or update transition_bias
      if self.transition_bias is None:
        self.transition_bias = transition_bias
        self.transition_bias_denominator = transition_bias_denominator
      else:
        self.transition_bias = (
            self.transition_bias * self.transition_bias_denominator +
            transition_bias * transition_bias_denominator) / (
                self.transition_bias_denominator + transition_bias_denominator)
        self.transition_bias_denominator += transition_bias_denominator

    # concatenate train_sequences
    (concatenated_train_sequence,
     concatenated_train_cluster_id) = utils.concatenate_training_data(
         train_sequences,
         train_cluster_ids,
         args.enforce_cluster_id_uniqueness,
         True)

    self.fit_concatenated(
        concatenated_train_sequence, concatenated_train_cluster_id, args)

  def _update_beam_state(self, beam_state, look_ahead_seq, cluster_seq):
    """Update a beam state given a look ahead sequence and known cluster
    assignments.

    Args:
      beam_state: A BeamState object.
      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
        look_ahead: number of steps to look ahead in the beam search.
        D: observation dimension
      cluster_seq: Cluster assignment sequence for look_ahead_seq.

    Returns:
      new_beam_state: An updated BeamState object.
    """

    loss = 0
    new_beam_state = BeamState(beam_state)
    for sub_idx, cluster in enumerate(cluster_seq):
      if cluster > len(new_beam_state.mean_set):  # invalid trace
        new_beam_state.neg_likelihood = float('inf')
        break
      elif cluster < len(new_beam_state.mean_set):  # existing cluster
        last_cluster = new_beam_state.trace[-1]
        loss = loss_func.weighted_mse_loss(
            input_tensor=torch.squeeze(new_beam_state.mean_set[cluster]),
            target_tensor=look_ahead_seq[sub_idx, :],
            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
        if cluster == last_cluster:
          loss -= np.log(1 - self.transition_bias)
        else:
          loss -= np.log(self.transition_bias) + np.log(
              new_beam_state.block_counts[cluster]) - np.log(
                  sum(new_beam_state.block_counts) + self.crp_alpha)
        # update new mean and new hidden
        mean, hidden = self.rnn_model(
            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
            new_beam_state.hidden_set[cluster])
        new_beam_state.mean_set[cluster] = (new_beam_state.mean_set[cluster]*(
            (np.array(new_beam_state.trace) == cluster).sum() -
            1).astype(float) + mean.clone()) / (
                np.array(new_beam_state.trace) == cluster).sum().astype(
                    float)  # use mean to predict
        new_beam_state.hidden_set[cluster] = hidden.clone()
        if cluster != last_cluster:
          new_beam_state.block_counts[cluster] += 1
        new_beam_state.trace.append(cluster)
      else:  # new cluster
        init_input = autograd.Variable(
            torch.zeros(self.observation_dim)
        ).unsqueeze(0).unsqueeze(0).to(self.device)
        mean, hidden = self.rnn_model(init_input,
                                      self.rnn_init_hidden)
        loss = loss_func.weighted_mse_loss(
            input_tensor=torch.squeeze(mean),
            target_tensor=look_ahead_seq[sub_idx, :],
            weight=1 / (2 * self.sigma2)).cpu().detach().numpy()
        loss -= np.log(self.transition_bias) + np.log(
            self.crp_alpha) - np.log(
                sum(new_beam_state.block_counts) + self.crp_alpha)
        # update new mean and new hidden
        mean, hidden = self.rnn_model(
            look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0),
            hidden)
        new_beam_state.append(mean, hidden, cluster)
      new_beam_state.neg_likelihood += loss
    return new_beam_state

  def _calculate_score(self, beam_state, look_ahead_seq):
    """Calculate negative log likelihoods for all possible state allocations
    of a look ahead sequence, according to the current beam state.

    Args:
      beam_state: A BeamState object.
      look_ahead_seq: Look ahead sequence, size: look_ahead*D.
        look_ahead: number of steps to look ahead in the beam search.
        D: observation dimension

    Returns:
      beam_score_set: a set of scores for each possible state allocation.
    """

    look_ahead, _ = look_ahead_seq.shape
    beam_num_clusters = len(beam_state.mean_set)
    beam_score_set = float('inf') * np.ones(
        beam_num_clusters + 1 + np.arange(look_ahead))
    for cluster_seq, _ in np.ndenumerate(beam_score_set):
      updated_beam_state = self._update_beam_state(beam_state,
                                                   look_ahead_seq, cluster_seq)
      beam_score_set[cluster_seq] = updated_beam_state.neg_likelihood
    return beam_score_set

  def predict_single(self, test_sequence, args):
    """Predict labels for a single test sequence using UISRNN model.

    Args:
      test_sequence: the test observation sequence, which is a 2-dim numpy
        array of real numbers, of size `N * D`.

        - `N`: length of one test utterance.
        - `D`: observation dimension.

        For example:
        ```
        test_sequence =
        [[2.2 -1.0 3.0 5.6]   --> 1st entry of utterance 'iccc'
         [0.5 1.8 -3.2 0.4]   --> 2nd entry of utterance 'iccc'
         [-2.2 5.0 1.8 3.7]   --> 3rd entry of utterance 'iccc'
         [-3.8 0.1 1.4 3.3]   --> 4th entry of utterance 'iccc'
         [0.1 2.7 3.5 -1.7]]  --> 5th entry of utterance 'iccc'
        ```
        Here `N=5`, `D=4`.
      args: Inference configurations. See `arguments.py` for details.

    Returns:
      predicted_cluster_id: predicted speaker id sequence, which is
        an array of integers, of size `N`.
        For example, `predicted_cluster_id = [0, 1, 0, 0, 1]`

    Raises:
      TypeError: If test_sequence is of wrong type.
      ValueError: If test_sequence has wrong dimension.
    """
    # check type
    if (not isinstance(test_sequence, np.ndarray) or
        test_sequence.dtype != float):
      raise TypeError('test_sequence should be a numpy array of float type.')
    # check dimension
    if test_sequence.ndim != 2:
      raise ValueError('test_sequence must be 2-dim array.')
    # check size
    test_sequence_length, observation_dim = test_sequence.shape
    if observation_dim != self.observation_dim:
      raise ValueError('test_sequence does not match the dimension specified '
                       'by args.observation_dim.')

    self.rnn_model.eval()
    test_sequence = np.tile(test_sequence, (args.test_iteration, 1))
    test_sequence = autograd.Variable(
        torch.from_numpy(test_sequence).float()).to(self.device)
    # bookkeeping for beam search
    beam_set = [BeamState()]
    for num_iter in np.arange(0, args.test_iteration * test_sequence_length,
                              args.look_ahead):
      max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set])
      look_ahead_seq = test_sequence[num_iter:  num_iter + args.look_ahead, :]
      look_ahead_seq_length = look_ahead_seq.shape[0]
      score_set = float('inf') * np.ones(
          np.append(
              args.beam_size, max_clusters + 1 + np.arange(
                  look_ahead_seq_length)))
      for beam_rank, beam_state in enumerate(beam_set):
        beam_score_set = self._calculate_score(beam_state, look_ahead_seq)
        score_set[beam_rank, :] = np.pad(
            beam_score_set,
            np.tile([[0, max_clusters - len(beam_state.mean_set)]],
                    (look_ahead_seq_length, 1)), 'constant',
            constant_values=float('inf'))
      # find top scores
      score_ranked = np.sort(score_set, axis=None)
      score_ranked[score_ranked == float('inf')] = 0
      score_ranked = np.trim_zeros(score_ranked)
      idx_ranked = np.argsort(score_set, axis=None)
      updated_beam_set = []
      for new_beam_rank in range(
          np.min((len(score_ranked), args.beam_size))):
        total_idx = np.unravel_index(idx_ranked[new_beam_rank],
                                     score_set.shape)
        prev_beam_rank = total_idx[0].item()
        cluster_seq = total_idx[1:]
        updated_beam_state = self._update_beam_state(
            beam_set[prev_beam_rank], look_ahead_seq, cluster_seq)
        updated_beam_set.append(updated_beam_state)
      beam_set = updated_beam_set
    predicted_cluster_id = beam_set[0].trace[-test_sequence_length:]
    return predicted_cluster_id

  def predict(self, test_sequences, args):
    """Predict labels for a single or many test sequences using UISRNN model.

    Args:
      test_sequences: Either a list of test sequences, or a single test
        sequence. Each test sequence is a 2-dim numpy array
        of real numbers. See `predict_single()` for details.
      args: Inference configurations. See `arguments.py` for details.

    Returns:
      predicted_cluster_ids: Predicted labels for test_sequences.

        1. if test_sequences is a list, predicted_cluster_ids will be a list
           of the same size, where each element is a 1-dim list of strings.
        2. if test_sequences is a single sequence, predicted_cluster_ids will
           be a 1-dim list of strings.

    Raises:
      TypeError: If test_sequences is of wrong type.
    """
    # check type
    if isinstance(test_sequences, np.ndarray):
      return self.predict_single(test_sequences, args)
    if isinstance(test_sequences, list):
      return [self.predict_single(test_sequence, args)
              for test_sequence in test_sequences]
    raise TypeError('test_sequences should be either a list or numpy array.')
````
Unbounded Interleaved-State Recurrent Neural Networks.
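A typical end-to-end sketch, mirroring the README; `train_sequence`, `train_cluster_id`, and `test_sequences` are placeholders for your own data, and the checkpoint path is arbitrary:

```python
import uisrnn

model_args, training_args, inference_args = uisrnn.parse_arguments()

# Construct the model from model configurations.
model = uisrnn.UISRNN(model_args)

# Train on a concatenated sequence plus its cluster IDs
# (or on lists of sequences/IDs; see fit() below).
model.fit(train_sequence, train_cluster_id, training_args)
model.save('saved_uisrnn_model.pth')

# Later: restore the model and run inference.
model.load('saved_uisrnn_model.pth')
predicted_cluster_ids = model.predict(test_sequences, inference_args)
```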
`def __init__(self, args)`
Construct the UISRNN object.
Args:
  args: Model configurations. See `arguments.py` for details.
`def save(self, filepath)`
Save the model to a file.
Args:
  filepath: the path of the file.
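For instance, with a trained `model` from the class-level sketch above (the path is a hypothetical example):

```python
# Persists the RNN weights plus transition_bias, crp_alpha and sigma2.
model.save('uisrnn_checkpoint.pth')
```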
`def load(self, filepath)`
Load the model from a file.
Args:
  filepath: the path of the file.
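And to restore that saved state into a freshly constructed model:

```python
restored = uisrnn.UISRNN(model_args)
restored.load('uisrnn_checkpoint.pth')  # same hypothetical path as above
```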
`def fit_concatenated(self, train_sequence, train_cluster_id, args)`
Fit UISRNN model to concatenated sequence and cluster_id.

Args:
  train_sequence: the training observation sequence, which is a
    2-dim numpy array of real numbers, of size `N * D`.

    - `N`: summation of lengths of all utterances.
    - `D`: observation dimension.

    For example,
    ```
    train_sequence =
    [[1.2 3.0 -4.1 6.0]   --> an entry of speaker #0 from utterance 'iaaa'
     [0.8 -1.1 0.4 0.5]   --> an entry of speaker #1 from utterance 'iaaa'
     [-0.2 1.0 3.8 5.7]   --> an entry of speaker #0 from utterance 'iaaa'
     [3.8 -0.1 1.5 2.3]   --> an entry of speaker #0 from utterance 'ibbb'
     [1.2 1.4 3.6 -2.7]]  --> an entry of speaker #0 from utterance 'ibbb'
    ```
    Here `N=5`, `D=4`.

    We concatenate all training utterances into this single sequence.
  train_cluster_id: the speaker id sequence, which is a 1-dim list or
    numpy array of strings, of size `N`.
    For example,
    ```
    train_cluster_id = ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']
    ```
    'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'.

    Note that the order of entries within an utterance is preserved,
    and all utterances are simply concatenated together.
  args: Training configurations. See `arguments.py` for details.

Raises:
  TypeError: If train_sequence or train_cluster_id is of wrong type.
  ValueError: If train_sequence or train_cluster_id has wrong dimension.
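A hedged sketch of calling this method directly; the arrays mirror the toy docstring example (in practice you would use real d-vector embeddings, and `observation_dim` is overridden here only to match the toy data):

```python
import numpy as np
import uisrnn

model_args, training_args, _ = uisrnn.parse_arguments()
model_args.observation_dim = 4  # match the toy data below
model = uisrnn.UISRNN(model_args)

# Toy concatenated observations (N=5, D=4) and their speaker labels.
train_sequence = np.array(
    [[1.2, 3.0, -4.1, 6.0],
     [0.8, -1.1, 0.4, 0.5],
     [-0.2, 1.0, 3.8, 5.7],
     [3.8, -0.1, 1.5, 2.3],
     [1.2, 1.4, 3.6, -2.7]])
train_cluster_id = ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0']

model.fit_concatenated(train_sequence, train_cluster_id, training_args)
```

Note that calling `fit_concatenated()` directly skips the transition_bias estimation performed by `fit()`, so it is usually preferable to call `fit()` instead.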
`def fit(self, train_sequences, train_cluster_ids, args)`
Fit UISRNN model.

Args:
  train_sequences: Either a list of training sequences, or a single
    concatenated training sequence:

    1. train_sequences is a list, and each element is a 2-dim numpy array
       of real numbers, of size `length * D`. The length varies among
       different sequences, but the `D` is the same. In speaker diarization,
       each sequence is the sequence of speaker embeddings of one utterance.
    2. train_sequences is a single concatenated sequence, which is a
       2-dim numpy array of real numbers. See `fit_concatenated()`
       for more details.
  train_cluster_ids: Ground truth labels for train_sequences:

    1. if train_sequences is a list, this must also be a list of the same
       size, each element being a 1-dim list or numpy array of strings.
    2. if train_sequences is a single concatenated sequence, this must also
       be the concatenated 1-dim list or numpy array of strings.
  args: Training configurations. See `arguments.py` for details.

Raises:
  TypeError: If train_sequences or train_cluster_ids is of wrong type.
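A sketch of the list form, with two hypothetical utterances of different lengths but the same `D` (random embeddings stand in for real d-vectors):

```python
import numpy as np
import uisrnn

model_args, training_args, _ = uisrnn.parse_arguments()
model = uisrnn.UISRNN(model_args)

# Two utterances of different lengths; D must equal model_args.observation_dim.
D = model_args.observation_dim
train_sequences = [
    np.random.randn(12, D),  # utterance 1: 12 embeddings
    np.random.randn(7, D),   # utterance 2: 7 embeddings
]
train_cluster_ids = [
    np.array(['utt1_spk0'] * 8 + ['utt1_spk1'] * 4),
    np.array(['utt2_spk0'] * 7),
]

# fit() estimates transition_bias from the per-utterance labels,
# concatenates everything, then calls fit_concatenated().
model.fit(train_sequences, train_cluster_ids, training_args)
```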
`def predict_single(self, test_sequence, args)`
Predict labels for a single test sequence using UISRNN model.

Args:
  test_sequence: the test observation sequence, which is a 2-dim numpy
    array of real numbers, of size `N * D`.

    - `N`: length of one test utterance.
    - `D`: observation dimension.

    For example:
    ```
    test_sequence =
    [[2.2 -1.0 3.0 5.6]   --> 1st entry of utterance 'iccc'
     [0.5 1.8 -3.2 0.4]   --> 2nd entry of utterance 'iccc'
     [-2.2 5.0 1.8 3.7]   --> 3rd entry of utterance 'iccc'
     [-3.8 0.1 1.4 3.3]   --> 4th entry of utterance 'iccc'
     [0.1 2.7 3.5 -1.7]]  --> 5th entry of utterance 'iccc'
    ```
    Here `N=5`, `D=4`.
  args: Inference configurations. See `arguments.py` for details.

Returns:
  predicted_cluster_id: predicted speaker id sequence, which is an array
    of integers, of size `N`. For example,
    `predicted_cluster_id = [0, 1, 0, 0, 1]`.

Raises:
  TypeError: If test_sequence is of wrong type.
  ValueError: If test_sequence has wrong dimension.
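A minimal inference sketch on the toy data from the docstring, assuming a trained or loaded `model` built with `observation_dim=4`, and `inference_args` from `parse_arguments()`:

```python
import numpy as np

# Toy test utterance with N=5 embeddings of dimension D=4.
test_sequence = np.array(
    [[2.2, -1.0, 3.0, 5.6],
     [0.5, 1.8, -3.2, 0.4],
     [-2.2, 5.0, 1.8, 3.7],
     [-3.8, 0.1, 1.4, 3.3],
     [0.1, 2.7, 3.5, -1.7]])

predicted_cluster_id = model.predict_single(test_sequence, inference_args)
print(predicted_cluster_id)  # e.g. [0, 1, 0, 0, 1]
```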
`def predict(self, test_sequences, args)`
Predict labels for a single or many test sequences using UISRNN model.

Args:
  test_sequences: Either a list of test sequences, or a single test
    sequence. Each test sequence is a 2-dim numpy array of real numbers.
    See `predict_single()` for details.
  args: Inference configurations. See `arguments.py` for details.

Returns:
  predicted_cluster_ids: Predicted labels for test_sequences.

  1. if test_sequences is a list, predicted_cluster_ids will be a list of
     the same size, where each element is a 1-dim list of strings.
  2. if test_sequences is a single sequence, predicted_cluster_ids will be
     a 1-dim list of strings.

Raises:
  TypeError: If test_sequences is of wrong type.
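A sketch covering both accepted input forms; `test_sequence`, `test_sequence_a`, and `test_sequence_b` are placeholder arrays as in the `predict_single()` example:

```python
# Single sequence in, single label sequence out.
labels = model.predict(test_sequence, inference_args)

# List of sequences in, list of label sequences out (same order).
all_labels = model.predict([test_sequence_a, test_sequence_b], inference_args)
```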
```python
def parallel_predict(model, test_sequences, args, num_processes=4):
  """Run prediction in parallel using torch.multiprocessing.

  This is a beta feature. It makes prediction slower on CPU. But it's reported
  that it makes prediction faster on GPU.

  Args:
    model: instance of UISRNN model
    test_sequences: a list of test sequences, or a single test
      sequence. Each test sequence is a 2-dim numpy array
      of real numbers. See `predict_single()` for details.
    args: Inference configurations. See `arguments.py` for details.
    num_processes: number of parallel processes.

  Returns:
    a list of the same size as test_sequences, where each element
    is a 1-dim list of strings.

  Raises:
    TypeError: If test_sequences is of wrong type.
  """
  if not isinstance(test_sequences, list):
    raise TypeError('test_sequences must be a list.')
  ctx = multiprocessing.get_context('forkserver')
  model.rnn_model.share_memory()
  pool = ctx.Pool(num_processes)
  results = pool.map(
      functools.partial(model.predict_single, args=args),
      test_sequences)
  pool.close()
  return results
```
Run prediction in parallel using torch.multiprocessing.

This is a beta feature. It makes prediction slower on CPU, but it is reported to make prediction faster on GPU.

Args:
  model: instance of UISRNN model.
  test_sequences: a list of test sequences. Each test sequence is a 2-dim
    numpy array of real numbers. See `predict_single()` for details.
    Unlike `predict()`, a bare single sequence is not accepted here: the
    function raises TypeError if test_sequences is not a list.
  args: Inference configurations. See `arguments.py` for details.
  num_processes: number of parallel processes.

Returns:
  a list of the same size as test_sequences, where each element is a
  1-dim list of strings.

Raises:
  TypeError: If test_sequences is of wrong type.
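A usage sketch; the `__main__` guard matters because `torch.multiprocessing` launches worker processes via the forkserver context. The checkpoint path and `test_sequences` list are placeholders:

```python
import uisrnn

if __name__ == '__main__':
    model_args, _, inference_args = uisrnn.parse_arguments()
    model = uisrnn.UISRNN(model_args)
    model.load('saved_uisrnn_model.pth')  # hypothetical checkpoint path

    # test_sequences must be a list; results come back in the same order.
    predicted = uisrnn.parallel_predict(
        model, test_sequences, inference_args, num_processes=4)
```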