GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks


Code Implementation
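For reference while reading the implementation below, these are the quantities GradNorm defines (W is the last shared layer's weights, α the asymmetry hyperparameter, and the target inside the absolute value is treated as a constant when differentiating with respect to the loss weights w_i):

```latex
% Per-task gradient norm on the last shared layer W, and its average over tasks
G_W^{(i)}(t) = \big\lVert \nabla_W \, w_i(t)\, L_i(t) \big\rVert_2, \qquad
\bar{G}_W(t) = \mathbb{E}_{\mathrm{task}}\big[ G_W^{(i)}(t) \big]

% Relative inverse training rate of task i
\tilde{L}_i(t) = \frac{L_i(t)}{L_i(0)}, \qquad
r_i(t) = \frac{\tilde{L}_i(t)}{\mathbb{E}_{\mathrm{task}}\big[\tilde{L}_j(t)\big]}

% GradNorm loss, minimized with respect to the loss weights w_i only
L_{\mathrm{grad}}\big(t; w_i(t)\big) = \sum_i \Big|\, G_W^{(i)}(t) - \bar{G}_W(t) \cdot \big[r_i(t)\big]^{\alpha} \Big|_1
```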

```python
import numpy as np
import tensorflow as tf


class GradNormLossesModel(tf.keras.Model):
    """Multi-task model whose per-task loss weights are learned with GradNorm."""

    def __init__(self, alpha=0.5, layer_name='gradnorm', *args, **kwargs):
        # Assumes the functional-API constructor (inputs/outputs passed via
        # *args/**kwargs), so self.output_names is populated after this call.
        super().__init__(*args, **kwargs)

        self.layer_name = layer_name
        self.alpha = alpha
        # L_i(0): initial loss of each task; log(2) approximates the initial
        # cross-entropy loss of an untrained binary classifier.
        self.L0 = [tf.Variable(np.log(2), trainable=False, dtype=tf.float32)
                   for _ in range(len(self.output_names))]
        # Learnable, non-negative loss weights w_i, one per task.
        self.loss_ws = [tf.Variable(1.0, trainable=True,
                                    constraint=tf.keras.constraints.NonNeg(),
                                    name=str(i) + '_loss_ws')
                        for i in range(len(self.output_names))]
        # Separate optimizer (with its own schedule) for the loss weights.
        lr_schedule_ws = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=1e-2,
            decay_steps=100,
            decay_rate=0.75)
        self.optimizer_ws = tf.keras.optimizers.Adam(learning_rate=lr_schedule_ws)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape(persistent=True) as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss values
            # (the loss functions are configured in `compile()`)

            losses_value = []     # raw per-task losses L_i(t)
            weighted_losses = []  # weighted losses w_i * L_i(t)
            total_loss = 0.0      # running weighted total loss
            for i in range(len(self.output_names)):
                target_name = self.output_names[i]
                Li = self.loss[target_name](y_true=y[target_name], y_pred=y_pred[i])
                losses_value.append(Li)
                w_Li = tf.multiply(self.loss_ws[i], Li)
                weighted_losses.append(w_Li)
                total_loss = tf.add(total_loss, w_Li)

            # Relative inverse training rate r_i(t) = (L_i(t)/L_i(0)) / mean_j(L_j(t)/L_j(0))
            loss_rate = tf.stack(losses_value, axis=0) / tf.stack(self.L0, axis=0)
            loss_rate_mean = tf.reduce_mean(loss_rate)
            loss_r = loss_rate / loss_rate_mean

            # Gradient norm of each weighted loss w.r.t. the last shared layer.
            # The name filter must match the last shared layer of your network.
            # Gradients are taken inside the tape so that gnorm_loss below stays
            # differentiable with respect to the loss weights.
            last_shared_layer_var = [l for l in self.trainable_variables
                                     if 'level_1_expert_shared' in l.name]
            grads = [tape.gradient(wLi, last_shared_layer_var) for wLi in weighted_losses]
            grads_mid = [tf.concat([tf.reshape(g, (-1, 1)) for g in gs], axis=0)
                         for gs in grads]
            # G_W^(i) is the L2 norm; alpha enters only through the target below.
            gnorms_mid = [tf.norm(g) for g in grads_mid]
            gnorms = tf.stack(gnorms_mid, axis=0)
            avgnorm = tf.reduce_mean(gnorms)

            # GradNorm loss: sum_i |G_W^(i) - avgnorm * r_i^alpha|,
            # with the target held constant via stop_gradient.
            grad_diff = tf.abs(gnorms - tf.stop_gradient(avgnorm * (loss_r ** self.alpha)))
            gnorm_loss = tf.reduce_sum(grad_diff)

        # Update the network parameters (excluding the loss weights) with the weighted loss.
        trainable_vars = [var for var in self.trainable_variables if '_loss_ws' not in var.name]
        gradients = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the loss weights w_i with the GradNorm loss.
        gradws = tape.gradient(gnorm_loss, self.loss_ws)
        self.optimizer_ws.apply_gradients(zip(gradws, self.loss_ws))

        # Renormalize the loss weights so they sum to the number of tasks,
        # which keeps their overall scale (and the effective learning rate) stable.
        factor = tf.divide(tf.constant(len(self.output_names), dtype=tf.float32),
                           tf.reduce_sum(self.loss_ws))
        for i in range(len(self.output_names)):
            self.loss_ws[i].assign(tf.multiply(factor, self.loss_ws[i]))

        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)

        # The tape is persistent, so release it explicitly.
        del tape

        return {m.name: m.result() for m in self.metrics}
```
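Below is a minimal sketch of how this model might be wired up and trained (TF 2.x). The toy two-task architecture and the names `level_1_expert_shared`, `task_a`, and `task_b` are illustrative assumptions: the shared-layer name only has to contain the substring that `train_step` filters on, and the loss dict keys must match the output layer names.

```python
import numpy as np
import tensorflow as tf

# Toy two-task network: one shared Dense layer feeding two sigmoid heads.
inputs = tf.keras.Input(shape=(16,))
shared = tf.keras.layers.Dense(32, activation='relu',
                               name='level_1_expert_shared')(inputs)
out_a = tf.keras.layers.Dense(1, activation='sigmoid', name='task_a')(shared)
out_b = tf.keras.layers.Dense(1, activation='sigmoid', name='task_b')(shared)

# Functional-API constructor, so self.output_names is available in __init__.
model = GradNormLossesModel(alpha=0.5, inputs=inputs, outputs=[out_a, out_b])
model.compile(optimizer='adam',
              loss={'task_a': tf.keras.losses.BinaryCrossentropy(),
                    'task_b': tf.keras.losses.BinaryCrossentropy()},
              metrics={'task_a': tf.keras.metrics.AUC(),
                       'task_b': tf.keras.metrics.AUC()})

# Random data, just to show the expected input/target structure.
x = np.random.rand(256, 16).astype('float32')
y = {'task_a': np.random.randint(0, 2, (256, 1)).astype('float32'),
     'task_b': np.random.randint(0, 2, (256, 1)).astype('float32')}
model.fit(x, y, batch_size=32, epochs=1)

print([w.numpy() for w in model.loss_ws])  # learned per-task loss weights
```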

Paper

Z. Chen, V. Badrinarayanan, C.-Y. Lee, A. Rabinovich. GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks. ICML 2018. arXiv:1711.02257

