End-to-End Multi-Task Learning with Attention

Last updated: November 29, 2022 (PM)

Code Implementation
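The loss weighting in the implementation below follows the Dynamic Weight Average (DWA) scheme from the paper: each task's weight is a softmax over the ratio of its losses at the two previous training steps, softened by a temperature T and rescaled so the K weights sum to the number of tasks K:

$$
w_k(t-1)=\frac{L_k(t-1)}{L_k(t-2)},\qquad
\lambda_k(t)=\frac{K\,\exp\big(w_k(t-1)/T\big)}{\sum_{i}\exp\big(w_i(t-1)/T\big)}
$$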

import tensorflow as tf


class LossesDWAModel(tf.keras.Model):

    def __init__(self, temperature=1, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.t = tf.constant(temperature, dtype=tf.float32)  # temperature T
        # Non-trainable variables holding the current per-task loss weights
        self.loss_ws = [tf.Variable(0.0, trainable=False, name=str(i) + '_loss_ws')
                        for i in range(len(self.output_names))]
        self.loss_1 = None  # per-task losses at step t-1
        self.loss_2 = None  # per-task losses at step t-2

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)

            task_loss = []  # per-task losses for the current batch
            last_w = []     # w_i(t-1) = L_i(t-1) / L_i(t-2)
            for i in range(len(self.output_names)):
                target_name = self.output_names[i]
                loss_i = self.loss[target_name](y_true=y[target_name], y_pred=y_pred[i])
                task_loss.append(loss_i)
                if self.loss_1 is not None and self.loss_2 is not None:
                    last_w.append(self.loss_1[i] / self.loss_2[i])
                else:
                    # First two steps: no loss history yet, fall back to equal weights
                    last_w.append(tf.constant(1.0, dtype=tf.float32))

            loss_weights_mid = tf.math.exp(tf.divide(last_w, self.t))
            loss_weights = tf.divide(loss_weights_mid, tf.reduce_sum(loss_weights_mid))

            # Rescale the normalized weights so they sum to the number of tasks K
            total_loss = 0.0
            factor = tf.divide(tf.constant(len(self.output_names), dtype=tf.float32),
                               tf.reduce_sum(loss_weights))
            for i in range(len(self.output_names)):
                lw = tf.multiply(factor, loss_weights[i])
                self.loss_ws[i].assign(lw)
                total_loss = tf.add(total_loss, tf.multiply(lw, task_loss[i]))

        # Shift the loss history: t-1 -> t-2, current -> t-1
        self.loss_2 = self.loss_1
        self.loss_1 = task_loss

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}
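A minimal usage sketch follows; the two-head model, the output names 'ctr'/'cvr', and the random data are illustrative assumptions, not part of the original post. Passing inputs/outputs through the subclass keeps the functional API, so self.output_names is already populated when the DWA weight variables are created. run_eagerly=True is set because the loss history (loss_1/loss_2) is plain Python state and would otherwise be frozen at trace time inside a tf.function.

    import numpy as np
    import tensorflow as tf

    inputs = tf.keras.Input(shape=(16,))
    shared = tf.keras.layers.Dense(32, activation='relu')(inputs)
    # Hypothetical task heads; any number of named outputs works the same way
    ctr_out = tf.keras.layers.Dense(1, activation='sigmoid', name='ctr')(shared)
    cvr_out = tf.keras.layers.Dense(1, activation='sigmoid', name='cvr')(shared)

    model = LossesDWAModel(temperature=2.0, inputs=inputs, outputs=[ctr_out, cvr_out])
    model.compile(
        optimizer='adam',
        # Loss instances keyed by output name, matching self.loss[target_name] above
        loss={'ctr': tf.keras.losses.BinaryCrossentropy(),
              'cvr': tf.keras.losses.BinaryCrossentropy()},
        metrics={'ctr': ['AUC'], 'cvr': ['AUC']},
        run_eagerly=True,  # keep the Python-side loss_1 / loss_2 history working
    )

    x = np.random.rand(256, 16).astype('float32')
    y = {'ctr': np.random.randint(0, 2, size=(256, 1)),
         'cvr': np.random.randint(0, 2, size=(256, 1))}
    model.fit(x, y, batch_size=32, epochs=1)

The current per-task weights live in model.loss_ws (non-trainable variables), so they can be read in a callback to monitor how DWA re-balances the tasks during training.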

Paper

Shikun Liu, Edward Johns, Andrew J. Davison. End-to-End Multi-Task Learning with Attention. CVPR 2019.

