1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
class LossesDWAModel(tf.keras.Model):
    """Keras model that balances its per-output losses with Dynamic Weight Average (DWA).

    At every training step the loss of output ``k`` is multiplied by

        lambda_k = K * softmax(w / T)_k,    w_k = L_k(t-1) / L_k(t-2),

    where ``K`` is the number of outputs, ``T`` is ``temperature``, and
    ``L_k(t-1)`` / ``L_k(t-2)`` are the losses of output ``k`` at the two
    previous training steps. Until two steps have been recorded every
    ``w_k`` is 1, so all outputs get weight 1.

    Note: the DWA paper (Liu et al., "End-to-End Multi-Task Learning with
    Attention") averages losses over an epoch; this class uses the losses
    of individual training steps (batches).

    Usage assumptions (not checked here -- confirm against callers):
      * the model is built functionally (``inputs=``/``outputs=`` passed via
        ``kwargs``), so ``output_names`` already exists in ``__init__``;
      * ``compile(loss=...)`` receives a dict keyed by output name;
      * targets ``y`` are a dict keyed by output name, and ``y_pred`` is a
        list/tuple ordered like ``output_names``;
      * each loss returns a scalar.

    Attributes:
        t: Softmax temperature, a float32 constant.
        loss_ws: One non-trainable variable per output holding the weight
            applied at the most recent training step (e.g. for logging).
        loss_1: Non-trainable float32 variable, shape ``[K]``, with the
            per-output losses of the previous training step.
        loss_2: Same as ``loss_1`` but for the step before that.
    """

    def __init__(self, temperature=1, *args, **kwargs):
        """Create the model.

        Args:
            temperature: DWA softmax temperature ``T``; larger values push
                the weights towards uniform.
            *args, **kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(*args, **kwargs)
        num_outputs = len(self.output_names)
        self.t = tf.constant(temperature, dtype=tf.float32)
        self.loss_ws = [
            tf.Variable(0.0, trainable=False, name=str(i) + '_loss_ws')
            for i in range(num_outputs)
        ]
        # The loss history must live in tf.Variables, not plain Python
        # attributes. Keras runs train_step inside tf.function, where Python
        # reassignment happens only once at trace time. The weights would
        # then never adapt, and graph tensors would leak out of the function.
        self.loss_1 = tf.Variable(
            tf.zeros([num_outputs]), trainable=False, name='dwa_loss_1')
        self.loss_2 = tf.Variable(
            tf.zeros([num_outputs]), trainable=False, name='dwa_loss_2')
        # Number of training steps whose losses have been recorded so far.
        self._dwa_steps = tf.Variable(
            0, trainable=False, dtype=tf.int64, name='dwa_steps')

    def train_step(self, data):
        """Run one optimisation step with DWA-weighted per-output losses.

        Args:
            data: An ``(x, y)`` pair; ``y`` is presumably a dict keyed by
                output name (see the class docstring).

        Returns:
            Dict mapping each name in ``self.metrics`` to its current result.
        """
        x, y = data
        num_outputs = len(self.output_names)

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)

            task_loss = [
                tf.cast(self.loss[name](y_true=y[name], y_pred=y_pred[i]),
                        tf.float32)
                for i, name in enumerate(self.output_names)
            ]

            # w_k = L_k(t-1) / L_k(t-2) once two steps are recorded, else 1.
            # divide_no_nan stops a zero previous loss from producing
            # inf/NaN weights (which would poison total_loss).
            ratio = tf.math.divide_no_nan(self.loss_1, self.loss_2)
            last_w = tf.where(self._dwa_steps >= 2, ratio, tf.ones_like(ratio))

            # K * softmax(w / T): positive weights that sum to K.
            loss_weights = (tf.cast(num_outputs, tf.float32)
                            * tf.nn.softmax(last_w / self.t))

            total_loss = 0.0
            for i, loss_i in enumerate(task_loss):
                self.loss_ws[i].assign(loss_weights[i])
                total_loss += loss_weights[i] * loss_i

        # Shift the history: this step's losses feed the next step's weights.
        self.loss_2.assign(self.loss_1)
        self.loss_1.assign(tf.stack(task_loss))
        self._dwa_steps.assign_add(1)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
|