uisrnn.utils
Utils for UIS-RNN.
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for UIS-RNN."""

import random
import string

import numpy as np
import torch
from torch import autograd
def generate_random_string(length=6):
  """Generate a random string of upper case letters and digits.

  Args:
    length: length of the generated string

  Returns:
    the generated string
  """
  return ''.join([
      random.choice(string.ascii_uppercase + string.digits)
      for _ in range(length)])
Generate a random string of upper case letters and digits.
Args:
  length: length of the generated string

Returns:
  the generated string
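
A quick usage sketch (the output is random, so the exact string will differ from the one shown in the comment):

result = generate_random_string(length=8)
print(result)  # e.g. 'Q7PZ4K2M': uppercase letters and digits, varies per call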
def enforce_cluster_id_uniqueness(cluster_ids):
  """Enforce uniqueness of cluster id across sequences.

  Args:
    cluster_ids: a list of 1-dim list/numpy.ndarray of strings

  Returns:
    a new list with the same length as cluster_ids

  Raises:
    TypeError: if cluster_ids or its element has wrong type
  """
  if not isinstance(cluster_ids, list):
    raise TypeError('cluster_ids must be a list')
  new_cluster_ids = []
  for cluster_id in cluster_ids:
    sequence_id = generate_random_string()
    if isinstance(cluster_id, np.ndarray):
      cluster_id = cluster_id.tolist()
    if not isinstance(cluster_id, list):
      raise TypeError('Elements of cluster_ids must be list or numpy.ndarray')
    new_cluster_id = ['_'.join([sequence_id, s]) for s in cluster_id]
    new_cluster_ids.append(new_cluster_id)
  return new_cluster_ids
Enforce uniqueness of cluster id across sequences.
Args:
  cluster_ids: a list of 1-dim list/numpy.ndarray of strings

Returns:
  a new list with the same length as cluster_ids

Raises:
  TypeError: if cluster_ids or its element has wrong type
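
A minimal sketch of the renaming behavior; the random prefixes in the comment are made up for illustration, since each call draws fresh ones:

cluster_ids = [['A', 'A', 'B'], ['A', 'B', 'B']]
new_ids = enforce_cluster_id_uniqueness(cluster_ids)
# 'A' in the first sequence no longer collides with 'A' in the second, e.g.:
# [['K93QZP_A', 'K93QZP_A', 'K93QZP_B'],
#  ['X2R8LM_A', 'X2R8LM_B', 'X2R8LM_B']]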
def concatenate_training_data(train_sequences, train_cluster_ids,
                              enforce_uniqueness=True, shuffle=True):
  """Concatenate training data.

  Args:
    train_sequences: a list of 2-dim numpy arrays to be concatenated
    train_cluster_ids: a list of 1-dim list/numpy.ndarray of strings
    enforce_uniqueness: a boolean indicating whether we should enforce
      uniqueness of train_cluster_ids
    shuffle: whether to randomly shuffle input order

  Returns:
    concatenated_train_sequence: a 2-dim numpy array
    concatenated_train_cluster_id: a list of strings

  Raises:
    TypeError: if input has wrong type
    ValueError: if sizes/dimensions of input or their elements are incorrect
  """
  # check input
  if not isinstance(train_sequences, list) or not isinstance(
      train_cluster_ids, list):
    raise TypeError('train_sequences and train_cluster_ids must be lists')
  if len(train_sequences) != len(train_cluster_ids):
    raise ValueError(
        'train_sequences and train_cluster_ids must have same size')
  train_cluster_ids = [
      x.tolist() if isinstance(x, np.ndarray) else x
      for x in train_cluster_ids]
  global_observation_dim = None
  for i, (train_sequence, train_cluster_id) in enumerate(
      zip(train_sequences, train_cluster_ids)):
    train_length, observation_dim = train_sequence.shape
    if i == 0:
      global_observation_dim = observation_dim
    elif global_observation_dim != observation_dim:
      raise ValueError(
          'train_sequences must have consistent observation dimension')
    if not isinstance(train_cluster_id, list):
      raise TypeError(
          'Elements of train_cluster_ids must be list or numpy.ndarray')
    if len(train_cluster_id) != train_length:
      raise ValueError(
          'Each train_sequence and its train_cluster_id must have same length')

  # enforce uniqueness
  if enforce_uniqueness:
    train_cluster_ids = enforce_cluster_id_uniqueness(train_cluster_ids)

  # random shuffle
  if shuffle:
    zipped_input = list(zip(train_sequences, train_cluster_ids))
    random.shuffle(zipped_input)
    train_sequences, train_cluster_ids = zip(*zipped_input)

  # concatenate
  concatenated_train_sequence = np.concatenate(train_sequences, axis=0)
  concatenated_train_cluster_id = [x for train_cluster_id in train_cluster_ids
                                   for x in train_cluster_id]
  return concatenated_train_sequence, concatenated_train_cluster_id
Concatenate training data.
Args:
  train_sequences: a list of 2-dim numpy arrays to be concatenated
  train_cluster_ids: a list of 1-dim list/numpy.ndarray of strings
  enforce_uniqueness: a boolean indicating whether we should enforce
    uniqueness of train_cluster_ids
  shuffle: whether to randomly shuffle input order

Returns:
  concatenated_train_sequence: a 2-dim numpy array
  concatenated_train_cluster_id: a list of strings

Raises:
  TypeError: if input has wrong type
  ValueError: if sizes/dimensions of input or their elements are incorrect
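
A minimal sketch with two toy utterances; the 256-dim random embeddings and the letter labels are made-up placeholders:

train_sequences = [np.random.rand(10, 256), np.random.rand(8, 256)]
train_cluster_ids = [['A'] * 6 + ['B'] * 4, np.array(['A'] * 8)]
seq, ids = concatenate_training_data(train_sequences, train_cluster_ids)
print(seq.shape)  # (18, 256)
print(len(ids))   # 18; labels carry random per-sequence prefixes,
                  # since enforce_uniqueness defaults to True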
def sample_permuted_segments(index_sequence, number_samples):
  """Sample sequences with permuted blocks.

  Args:
    index_sequence: (integer array, size: L)
      - subsequence index
      For example, index_sequence = [1,2,6,10,11,12].
    number_samples: (integer)
      - number of subsampled block-preserving permuted sequences.
      For example, number_samples = 5

  Returns:
    sampled_index_sequences: (a list of numpy arrays) - a list of subsampled
      block-preserving permuted sequences. For example,
      ```
      sampled_index_sequences =
      [[10,11,12,1,2,6],
       [6,1,2,10,11,12],
       [1,2,10,11,12,6],
       [6,1,2,10,11,12],
       [1,2,6,10,11,12]]
      ```
      The length of "sampled_index_sequences" is "number_samples".
  """
  segments = []
  if len(index_sequence) == 1:
    segments.append(index_sequence)
  else:
    prev = 0
    for i in range(len(index_sequence) - 1):
      if index_sequence[i + 1] != index_sequence[i] + 1:
        segments.append(index_sequence[prev:(i + 1)])
        prev = i + 1
      if i + 1 == len(index_sequence) - 1:
        segments.append(index_sequence[prev:])
  # sample permutations
  sampled_index_sequences = []
  for _ in range(number_samples):
    segments_array = []
    permutation = np.random.permutation(len(segments))
    for permutation_item in permutation:
      segments_array.append(segments[permutation_item])
    sampled_index_sequences.append(np.concatenate(segments_array))
  return sampled_index_sequences
Sample sequences with permuted blocks.
Args:
  index_sequence: (integer array, size: L) - subsequence index.
    For example, index_sequence = [1,2,6,10,11,12].
  number_samples: (integer) - number of subsampled block-preserving
    permuted sequences. For example, number_samples = 5.

Returns:
  sampled_index_sequences: (a list of numpy arrays) - a list of subsampled
    block-preserving permuted sequences. For example,

    sampled_index_sequences =
      [[10,11,12,1,2,6],
       [6,1,2,10,11,12],
       [1,2,10,11,12,6],
       [6,1,2,10,11,12],
       [1,2,6,10,11,12]]

    The length of "sampled_index_sequences" is "number_samples".
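
In other words, the function first splits index_sequence into maximal blocks of consecutive indices ([1,2], [6] and [10,11,12] in the example above), then shuffles only the block order for each sample. A minimal sketch:

samples = sample_permuted_segments(np.array([1, 2, 6, 10, 11, 12]), 3)
for s in samples:
  print(s)  # e.g. [10 11 12  1  2  6]; block order varies per sample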
def resize_sequence(sequence, cluster_id, num_permutations=None):
  """Resize sequences for packing and batching.

  Args:
    sequence: (real numpy matrix, size: seq_len*obs_size) - observed sequence
    cluster_id: (numpy vector, size: seq_len) - cluster indicator sequence
    num_permutations: int - number of permutations sampled per utterance.

  Returns:
    sub_sequences: a list of numpy arrays, with observation vectors from the
      same cluster in the same list.
    seq_lengths: the length of each cluster, plus one.
  """
  # merge sub-sequences that belong to a single cluster to a single sequence
  unique_id = np.unique(cluster_id)
  sub_sequences = []
  seq_lengths = []
  if num_permutations and num_permutations > 1:
    for i in unique_id:
      idx_set = np.where(cluster_id == i)[0]
      sampled_idx_sets = sample_permuted_segments(idx_set, num_permutations)
      for j in range(num_permutations):
        sub_sequences.append(sequence[sampled_idx_sets[j], :])
        seq_lengths.append(len(idx_set) + 1)
  else:
    for i in unique_id:
      idx_set = np.where(cluster_id == i)
      sub_sequences.append(sequence[idx_set, :][0])
      seq_lengths.append(len(idx_set[0]) + 1)
  return sub_sequences, seq_lengths
Resize sequences for packing and batching.
Args:
  sequence: (real numpy matrix, size: seq_len*obs_size) - observed sequence
  cluster_id: (numpy vector, size: seq_len) - cluster indicator sequence
  num_permutations: int - number of permutations sampled per utterance.

Returns:
  sub_sequences: a list of numpy arrays, with observation vectors from the
    same cluster in the same list.
  seq_lengths: the length of each cluster, plus one.
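
A minimal sketch grouping a toy 6-frame sequence into its two clusters (the 4-dim random observations are placeholders):

sequence = np.random.rand(6, 4)
cluster_id = np.array(['A', 'A', 'B', 'A', 'B', 'B'])
sub_sequences, seq_lengths = resize_sequence(sequence, cluster_id)
print([s.shape for s in sub_sequences])  # [(3, 4), (3, 4)]
print(seq_lengths)                       # [4, 4], i.e. cluster length + 1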
def pack_sequence(
    sub_sequences, seq_lengths, batch_size, observation_dim, device):
  """Pack sequences for training.

  Args:
    sub_sequences: a list of numpy arrays, with observation vectors from the
      same cluster in the same list.
    seq_lengths: the length of each cluster, plus one.
    batch_size: int or None - run batch learning if batch_size is None;
      otherwise run online learning with the specified batch size.
    observation_dim: int - dimension of the observation vectors.
    device: str - your device, e.g. `cuda:0` or `cpu`.

  Returns:
    packed_rnn_input: (PackedSequence object) packed rnn input
    rnn_truth: ground truth, the input shifted by one step
  """
  num_clusters = len(seq_lengths)
  sorted_seq_lengths = np.sort(seq_lengths)[::-1]
  permute_index = np.argsort(seq_lengths)[::-1]

  if batch_size is None:
    rnn_input = np.zeros((sorted_seq_lengths[0],
                          num_clusters,
                          observation_dim))
    # row 0 stays all-zero: the zero vector serves as the initial observation
    for i in range(num_clusters):
      rnn_input[1:sorted_seq_lengths[i], i, :] = sub_sequences[permute_index[i]]
    rnn_input = autograd.Variable(
        torch.from_numpy(rnn_input).float()).to(device)
    packed_rnn_input = torch.nn.utils.rnn.pack_padded_sequence(
        rnn_input, sorted_seq_lengths, batch_first=False)
  else:
    # sample a mini-batch of clusters (with replacement); the sampled indices
    # point into the length-descending order, so mini_batch[0] is the longest
    mini_batch = np.sort(np.random.choice(num_clusters, batch_size))
    rnn_input = np.zeros((sorted_seq_lengths[mini_batch[0]],
                          batch_size,
                          observation_dim))
    for i in range(batch_size):
      rnn_input[1:sorted_seq_lengths[mini_batch[i]], i, :] = (
          sub_sequences[permute_index[mini_batch[i]]])
    rnn_input = autograd.Variable(
        torch.from_numpy(rnn_input).float()).to(device)
    packed_rnn_input = torch.nn.utils.rnn.pack_padded_sequence(
        rnn_input, sorted_seq_lengths[mini_batch], batch_first=False)
  # ground truth is the shifted input
  rnn_truth = rnn_input[1:, :, :]
  return packed_rnn_input, rnn_truth
Pack sequences for training.
Args:
  sub_sequences: a list of numpy arrays, with observation vectors from the
    same cluster in the same list.
  seq_lengths: the length of each cluster, plus one.
  batch_size: int or None - run batch learning if batch_size is None;
    otherwise run online learning with the specified batch size.
  observation_dim: int - dimension of the observation vectors.
  device: str - your device, e.g. `cuda:0` or `cpu`.

Returns:
  packed_rnn_input: (PackedSequence object) packed rnn input
  rnn_truth: ground truth, the input shifted by one step
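
A minimal sketch chaining resize_sequence into pack_sequence, continuing the toy data from the resize_sequence example above (batch learning on CPU; note that newer PyTorch versions may be stricter about the lengths argument passed to pack_padded_sequence):

packed_rnn_input, rnn_truth = pack_sequence(
    sub_sequences, seq_lengths, batch_size=None,
    observation_dim=4, device='cpu')
print(rnn_truth.shape)  # torch.Size([3, 2, 4]):
                        # (max_length - 1, num_clusters, observation_dim)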
def output_result(model_args, training_args, test_record):
  """Produce a string to summarize the experiment."""
  accuracy_array, _ = zip(*test_record)
  total_accuracy = np.mean(accuracy_array)
  output_string = """
Config:
  sigma_alpha: {}
  sigma_beta: {}
  crp_alpha: {}
  learning rate: {}
  regularization: {}
  batch size: {}

Performance:
  averaged accuracy: {:.6f}
  accuracy numbers for all testing sequences:
  """.strip().format(
      training_args.sigma_alpha,
      training_args.sigma_beta,
      model_args.crp_alpha,
      training_args.learning_rate,
      training_args.regularization_weight,
      training_args.batch_size,
      total_accuracy)
  for accuracy in accuracy_array:
    output_string += '\n    {:.6f}'.format(accuracy)
  output_string += '\n' + '=' * 80 + '\n'
  filename = 'layer_{}_{}_{:.1f}_result.txt'.format(
      model_args.rnn_hidden_size,
      model_args.rnn_depth, model_args.rnn_dropout)
  with open(filename, 'a') as file_object:
    file_object.write(output_string)
  return output_string
Produce a string to summarize the experiment.
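
A minimal sketch, assuming model_args and training_args come from uisrnn.parse_arguments() and that test_record is a list of (accuracy, sequence length) tuples collected during testing, as only the accuracy entries are used here:

import uisrnn

model_args, training_args, _ = uisrnn.parse_arguments()
test_record = [(0.9543, 100), (0.9028, 80)]  # (accuracy, length) pairs
summary = output_result(model_args, training_args, test_record)
print(summary)  # the same text is appended to a layer_{...}_result.txt file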
def estimate_transition_bias(cluster_ids, smooth=1):
  """Estimate the transition bias.

  Args:
    cluster_ids: either a list of cluster indicator sequences, or a single
      concatenated sequence. The former is strongly preferred, since the
      transition_bias estimated from the latter will be inaccurate.
    smooth: int or float - smoothing coefficient. It avoids a -inf value in
      np.log for a sequence with a single speaker, and division by 0 for
      empty sequences. A small value for smooth decreases the bias in the
      calculation of transition_bias, but can also lead to underflow in some
      remote cases; larger values are safer but less accurate.

  Returns:
    bias: flipping coin head probability, i.e. the estimated probability of
      a label change between adjacent observations.
    bias_denominator: the denominator of the bias, used for multiple calls
      to fit().
  """
  transit_num = smooth
  bias_denominator = 2 * smooth
  for cluster_id_seq in cluster_ids:
    for entry in range(len(cluster_id_seq) - 1):
      transit_num += (cluster_id_seq[entry] != cluster_id_seq[entry + 1])
      bias_denominator += 1
  bias = transit_num / bias_denominator
  return bias, bias_denominator
Estimate the transition bias.
Args:
  cluster_ids: either a list of cluster indicator sequences, or a single
    concatenated sequence. The former is strongly preferred, since the
    transition_bias estimated from the latter will be inaccurate.
  smooth: int or float - smoothing coefficient. It avoids a -inf value in
    np.log for a sequence with a single speaker, and division by 0 for empty
    sequences. A small value for smooth decreases the bias in the calculation
    of transition_bias, but can also lead to underflow in some remote cases;
    larger values are safer but less accurate.

Returns:
  bias: flipping coin head probability, i.e. the estimated probability of a
    label change between adjacent observations.
  bias_denominator: the denominator of the bias, used for multiple calls
    to fit().
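
A minimal worked example: two label changes out of five adjacent pairs, with the default smooth=1:

cluster_ids = [['A', 'A', 'B'], ['C', 'C', 'C', 'D']]
bias, bias_denominator = estimate_transition_bias(cluster_ids)
# transit_num = 1 + 2 = 3 and bias_denominator = 2 + 5 = 7, so bias = 3/7
print(bias, bias_denominator)  # 0.42857142857142855 7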