import tensorflow.compat.v2 as tf from absl import flags FLAGS = flags.FLAGS class Whitening1D(tf.keras.layers.Layer): def __init__(self, eps=0, **kwargs): super(Whitening1D, self).__init__(**kwargs) self.eps = eps def call(self, x): bs, c = x.shape x_t = tf.transpose(x, (1, 0)) m = tf.reduce_mean(x_t, axis=1, keepdims=True) f = x_t - m ff_apr = tf.matmul(f, f, transpose_b=True) / (tf.cast(bs, tf.float32) - 1.0) ff_apr_shrinked = (1 - self.eps) * ff_apr + tf.eye(c) * self.eps sqrt = tf.linalg.cholesky(ff_apr_shrinked) inv_sqrt = tf.linalg.triangular_solve(sqrt, tf.eye(c)) f_hat = tf.matmul(inv_sqrt, f) decorelated = tf.transpose(f_hat, (1, 0)) return decorelated def w_mse_loss(x): """ input x shape = (batch size * num_samples, proj_out_dim) """ w = Whitening1D() num_samples = FLAGS.num_samples num_slice = num_samples * FLAGS.train_batch_size // (2 * FLAGS.proj_out_dim) x_split = tf.split(x, num_slice, 0) for i in range(num_slice): x_split[i] = w(x_split[i]) x = tf.concat(x_split, 0) x = tf.math.l2_normalize(x, -1) x_split = tf.split(x, num_samples, 0) loss = 0 for i in range(num_samples - 1): for j in range(i + 1, num_samples): v = x_split[i] * x_split[j] loss += 2 - 2 * tf.reduce_mean(tf.reduce_sum(v, -1)) loss /= num_samples * (num_samples - 1) // 2 return loss