I coded this Transformer from scratch for learning. It is based on The Annotated Transformer by Harvard NLP, which uses PyTorch.
I tested this with a toy problem so that data loading, tokenizing, etc. code is not needed.
🧸 The toy problem is to reverse a given sequence whilst replacing every even repetition of a digit with a special token (X).
For example,
input = 0 1 5 9 0 3 5 2 5
input after replacing even repetitions: 0 1 5 9 X 3 X 2 5
reversed = 5 2 X 3 X 9 5 1 0
🎫 If someone reading this has any questions or comments please find me on Twitter, @vpj.
import math import numpy as np
import tensorflow as tf
The mean and standard deviation are calculated along the last dimension.
def get_mean_std(x: tf.Tensor):
    """Return the mean and standard deviation of `x` along its last axis.

    Both results keep the reduced axis (size 1) so that they broadcast
    against `x`.
    """
    mean = tf.reduce_mean(x, axis=-1, keepdims=True)
    # variance is the mean squared deviation from the mean
    variance = tf.reduce_mean(tf.square(x - mean), axis=-1, keepdims=True)
    return mean, tf.sqrt(variance)
def layer_norm(layer: tf.Tensor):
    """Normalize `layer` along its last axis and apply a learned affine map.

    Creates two trainable vectors, `scale` and `base`, under the variable
    scope "norm"; each has one entry per feature of the last axis.

    `scale` starts at ones and `base` at zeros so that the layer initially
    performs a plain normalisation. Without explicit initializers
    `tf.get_variable` falls back to a random (Glorot uniform) initializer,
    which would randomly rescale and flip the sign of every feature.
    """
    with tf.variable_scope("norm"):
        scale = tf.get_variable("scale", shape=layer.shape[-1], dtype=tf.float32,
                                initializer=tf.ones_initializer())
        base = tf.get_variable("base", shape=layer.shape[-1], dtype=tf.float32,
                               initializer=tf.zeros_initializer())
        mean, std = get_mean_std(layer)

        # the small constant avoids a division by zero when std is 0
        norm = (layer - mean) / (std + 1e-6)

        # Adjust by learned scale and base
        return norm * scale + base
The inputs query
$Q$, key
$K$ and value
$V$ have form [batches, heads, sequence, features]
.
$d_k$ is the number of features; i.e. size of the last axis.
def attention(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, *,
              mask: tf.Tensor, keep_prob: float):
    """Scaled dot-product attention.

    `query`, `key` and `value` have form [batches, heads, sequence, features].
    `mask` has form [batches, heads, sequence, sequence] (or a shape that
    broadcasts to it); positions where it is 0 receive no attention.

    Returns a pair: the attended values
    $\mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$ and the attention
    matrix (after dropout).
    """
    # $d_k$ is the number of features
    d_k = query.shape[-1].value

    # Attention scores $\frac{Q K^T}{\sqrt{d_k}}$.
    # Transposing the last two axes of the key and multiplying gives, in the
    # last two axes, the matrix $S_{i,j} = Q_i \cdot K_j$ where $i$ and $j$
    # are positions along the sequence.
    key_transposed = tf.transpose(key, perm=[0, 1, 3, 2])
    scores = tf.matmul(query, key_transposed) / tf.constant(math.sqrt(d_k))

    # `scores` has form [batches, heads, sequence, sequence]; each row of the
    # last two dimensions is an attention vector.
    # Replace the scores with -1e9 everywhere `mask` is 0, so that the
    # $\mathop{softmax}$ gives zero attention to those positions.
    # `scores * 0` is only there to give the fill value the shape of `scores`.
    fill = ((scores * 0) - 1e9) * (tf.constant(1.) - mask)
    scores = scores * mask + fill

    # $(i, j)$ entry of the attention matrix gives the attention from the
    # $i^{th}$ position to the $j^{th}$ position.
    # Dropout is applied to the attention weights to improve generalization.
    attn = tf.nn.dropout(tf.nn.softmax(scores, axis=-1), keep_prob)

    return tf.matmul(attn, value), attn
This prepares query
$Q$, key
$K$ and value
$V$ that have form [batches, sequence, features]
.
def prepare_for_multi_head_attention(x: tf.Tensor, heads: int, name: str):
    """Project `x` linearly and split the features into `heads` heads.

    `x` has form [batches, sequence, features]; the result has form
    [batches, heads, sequence, features / heads]. `name` names the dense
    layer that performs the projection. The number of features must be
    divisible by `heads`.
    """
    # $d_{model}$ is the number of features
    n_batches, seq_len, d_model = x.shape

    # $d_k$ is the number of features per head
    assert d_model % heads == 0
    d_k = d_model // heads

    # linear transformation
    projected = tf.layers.dense(x, units=d_model, name=name)

    # split into heads: [batches, sequence, heads, features]
    per_head = tf.reshape(projected, shape=[n_batches, seq_len, heads, d_k])

    # move the heads axis forward: [batches, heads, sequence, features]
    return tf.transpose(per_head, perm=[0, 2, 1, 3])
The inputs query
$Q$, key
$K$ and value
$V$ have form [batches, sequence, features]
.
def multi_head_attention(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, *,
                         mask: tf.Tensor, heads: int, keep_prob: float):
    """Multi-head attention followed by a linear layer.

    `query`, `key` and `value` have form [batches, sequence, features].
    `mask` has form [batches, sequence, sequence]. The result has the same
    form as `query`.
    """
    with tf.variable_scope("multi_head"):
        # $d_{model}$ is the number of features
        n_batches, seq_len, d_model = query.shape

        # Apply linear transformations and split to multiple heads.
        # The resulting tensors have form [batches, heads, sequence, features]
        query, key, value = (
            prepare_for_multi_head_attention(tensor, heads, layer_name)
            for tensor, layer_name in ((query, "query"), (key, "key"), (value, "value"))
        )

        # expand `mask` from [batches, sequence, sequence]
        # to [batches, heads, sequence, sequence]
        head_mask = tf.expand_dims(mask, axis=1)

        # output of the attention layer; the attention matrix is not needed
        attended, _ = attention(query, key, value, mask=head_mask, keep_prob=keep_prob)

        # back from [batches, heads, sequence, features]
        # to [batches, sequence, heads, features]
        attended = tf.transpose(attended, perm=[0, 2, 1, 3])

        # merge the heads: [batches, sequence, features]
        merged = tf.reshape(attended, shape=[n_batches, seq_len, d_model])

        # pass through a linear layer
        return tf.layers.dense(merged, units=d_model, name="attention")
def feed_forward(x: tf.Tensor, d_model: int, d_ff: int, keep_prob: float):
    """Two dense layers with a ReLU and dropout in between.

    The hidden layer has `d_ff` units and the output layer `d_model` units.
    """
    with tf.variable_scope("feed_forward"):
        activated = tf.nn.relu(tf.layers.dense(x, units=d_ff, name="hidden"))
        dropped = tf.nn.dropout(activated, keep_prob=keep_prob)
        return tf.layers.dense(dropped, units=d_model, name="out")
This is a single encoder layer. The encoder consists of multiple such layers.
x
has the form [batches, sequence, features]
def encoder_layer(x: tf.Tensor, *,
                  mask: tf.Tensor, index: int, heads: int,
                  keep_prob: float, d_ff: int):
    """A single encoder layer: self-attention, then a feed-forward network.

    `x` has the form [batches, sequence, features] and `mask` is handed to
    `multi_head_attention`. `index` is only used to give the layer's
    variable scopes unique names. Returns the layer output, which has the
    same form as `x`.

    Each sub-layer is followed by dropout, a residual connection and layer
    normalization. The normalization is done inside the sub-layer's variable
    scope so that every sub-layer gets its own `norm` variables.
    """
    # $d_{model}$ is the number of features
    d_model = x.shape[-1]

    with tf.variable_scope(f"attention_{index}"):
        attention_out = multi_head_attention(x, x, x,
                                             mask=mask, heads=heads, keep_prob=keep_prob)
        # add a residual connection
        added = x + tf.nn.dropout(attention_out, keep_prob)
        x = layer_norm(added)

    with tf.variable_scope(f"ff_{index}"):
        ff_out = feed_forward(x, d_model, d_ff, keep_prob)
        # add a residual connection
        added = x + tf.nn.dropout(ff_out, keep_prob)
        return layer_norm(added)
Encoder consists of n_layers
encoder layers.
def encoder(x: tf.Tensor, *,
            mask: tf.Tensor, n_layers: int,
            heads: int, keep_prob: float, d_ff: int):
    """Run `x` through `n_layers` encoder layers, one after the other.

    All layers are created under the variable scope "encoder"; the output of
    each layer is the input of the next one.
    """
    with tf.variable_scope("encoder"):
        for layer_index in range(n_layers):
            x = encoder_layer(x,
                              mask=mask, index=layer_index,
                              heads=heads, keep_prob=keep_prob, d_ff=d_ff)
    return x
This is a single decoder layer. The decoder consists of multiple such layers.
encoding
is the final output from the encoder.
It has the form [batches, sequence, features]
.
enc_mask
is the mask for encoding
, of the form [batches, sequence, sequence]
.
x
is the previous output from the decoder.
During training we supply the true values for x.
It has the form [batches, sequence, features]
mask
is the mask for x
, of the form [batches, sequence, sequence]
.
def decoder_layer(encoding: tf.Tensor, x: tf.Tensor, *,
                  enc_mask: tf.Tensor, mask: tf.Tensor,
                  index: int, heads: int, keep_prob: float, d_ff: int):
    """A single decoder layer.

    `encoding` is the final output from the encoder and `x` is the previous
    output from the decoder; both have the form
    [batches, sequence, features]. `enc_mask` and `mask` are the masks for
    `encoding` and `x` respectively. `index` is only used to give the
    layer's variable scopes unique names. Returns the layer output, which
    has the same form as `x`.

    The layer has three sub-layers: self-attention over `x`, attention to
    `encoding`, and a feed-forward network. Each is followed by dropout, a
    residual connection and layer normalization, and its result feeds the
    next sub-layer. The normalization is done inside the sub-layer's
    variable scope so that every sub-layer gets its own `norm` variables.
    """
    # $d_{model}$ is the number of features
    d_model = encoding.shape[-1]

    with tf.variable_scope(f"{index}_self_attention"):
        attention_out = multi_head_attention(x, x, x,
                                             mask=mask, heads=heads, keep_prob=keep_prob)
        # add a residual connection
        added = x + tf.nn.dropout(attention_out, keep_prob=keep_prob)
        x = layer_norm(added)

    # Attention to the output from the encoder `encoding`
    with tf.variable_scope(f"{index}_encoding_attention"):
        attention_out = multi_head_attention(x, encoding, encoding,
                                             mask=enc_mask, heads=heads, keep_prob=keep_prob)
        # add a residual connection
        added = x + tf.nn.dropout(attention_out, keep_prob=keep_prob)
        x = layer_norm(added)

    with tf.variable_scope(f"{index}_ff"):
        ff_out = feed_forward(x, d_model, d_ff, keep_prob)
        # add a residual connection
        added = x + tf.nn.dropout(ff_out, keep_prob)
        return layer_norm(added)
Decoder consists of n_layers
decoder layers.
def decoder(encoding: tf.Tensor, x: tf.Tensor, *,
            enc_mask: tf.Tensor, mask: tf.Tensor,
            n_layers: int,
            heads: int, keep_prob: float, d_ff: int):
    """Run `x` through `n_layers` decoder layers, one after the other.

    All layers are created under the variable scope "decoder". Every layer
    receives the same `encoding` and masks; the output of each layer is the
    `x` of the next one.
    """
    with tf.variable_scope("decoder"):
        for layer_index in range(n_layers):
            x = decoder_layer(encoding, x,
                              enc_mask=enc_mask, mask=mask, index=layer_index,
                              heads=heads, keep_prob=keep_prob, d_ff=d_ff)
    return x
We use a table look up to get embeddings. The table is a trainable variable, so the embeddings get learned during training.
def get_embeddings(input_ids: tf.Tensor, output_ids: tf.Tensor,
                   vocab_size: int, d_model: int):
    """Look up embeddings for the encoder inputs and the decoder outputs.

    Both look-ups share one trainable table of shape [vocab_size, d_model],
    initialised from a normal distribution.

    Returns the table, the input embeddings and the output embeddings.
    """
    table = tf.get_variable("word_embeddings",
                            shape=[vocab_size, d_model],
                            dtype=tf.float32,
                            initializer=tf.initializers.random_normal())

    # embeddings of inputs (for the encoder) and of outputs (for the decoder)
    in_emb, out_emb = (tf.nn.embedding_lookup(table, ids)
                       for ids in (input_ids, output_ids))

    return table, in_emb, out_emb
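The positional encoding encodes the position along the sequence into a set of $d_{model}$ features:
$PE_{p,2i} = \sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$ and $PE_{p,2i+1} = \cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$,
where $p$ is the position along the sequence and $i$ is the index of the feature.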
def generate_positional_encodings(d_model: int, max_len: int = 5000):
    """Build the sinusoidal positional encodings as a constant tensor.

    Even feature indices hold the sine and odd feature indices the cosine
    of `position * 10000^(-2i / d_model)`.

    Returns a float32 constant of shape [1, max_len, d_model]. `d_model` may
    be odd; the last sine column then has no matching cosine column.
    """
    encodings = np.zeros((max_len, d_model), dtype=float)
    position = np.arange(0, max_len).reshape((max_len, 1))
    two_i = np.arange(0, d_model, 2)

    # $10000^{-\frac{2i}{d_{model}}}$
    div_term = np.exp(-math.log(10000.0) * two_i / d_model)
    angles = position * div_term

    encodings[:, 0::2] = np.sin(angles)
    # For an odd d_model there is one cosine column fewer than sine columns,
    # so only the first d_model // 2 frequencies are used here.
    encodings[:, 1::2] = np.cos(angles[:, :d_model // 2])

    # convert to a TensorFlow tensor from NumPy
    return tf.constant(encodings.reshape((1, max_len, d_model)),
                       dtype=tf.float32, name="positional_encodings")
Add positional encodings, and normalize embeddings before the encode or decode stages.
def prepare_embeddings(x: tf.Tensor, *,
                       positional_encodings: tf.Tensor,
                       keep_prob: float, is_input: bool):
    """Add positional encodings to the embeddings `x`, then apply dropout
    and layer normalization.

    `is_input` only selects the variable scope name ("prepare_input" or
    "prepare_output"). Returns the prepared embeddings, which have the same
    form as `x`.
    """
    name = "prepare_input" if is_input else "prepare_output"
    with tf.variable_scope(name):
        _, seq_len, _ = x.shape

        # add positional encodings; only the first seq_len positions are needed
        x = x + positional_encodings[:, :seq_len, :]

        x = tf.nn.dropout(x, keep_prob)

        # normalize, and hand the result back to the caller
        return layer_norm(x)
Get the final outputs by sending the output of the decoder through linear layer and softmax activation.
def generator(x: tf.Tensor, *, vocab_size: int):
    """Map the decoder output to log-probabilities over the vocabulary.

    `x` goes through a dense layer named "generator" with `vocab_size`
    units, followed by a log-softmax along the last axis.
    """
    logits = tf.layers.dense(x, units=vocab_size, name="generator")
    return tf.nn.log_softmax(logits, axis=-1)
This prevents the model from becoming over confident on certain results. Another alternative could be to add a small entropy loss.
Here, instead of making the probabilities for expected 1
and 0
for others, we set the log-probabilities of expected to be 1 - smoothing
and others to smoothing / (vocab_size - 1)
.
def label_smoothing_loss(results: tf.Tensor, expected: tf.Tensor, *,
                         vocab_size: int, smoothing: float):
    """KL divergence between the smoothed target distribution and the
    predicted distribution, averaged over all positions.

    `results` holds the predicted logits or log-probabilities, with the
    vocabulary on the last axis; `expected` holds the expected token ids.
    The target distribution gives probability `1 - smoothing` to the
    expected token and `smoothing / (vocab_size - 1)` to every other token.

    The smoothed targets are probabilities, so they are used directly
    rather than being fed to a distribution as logits. The divergence is
    taken from the target to the prediction, which stays finite when
    `smoothing` is 0.
    """
    results = tf.reshape(results, shape=(-1, vocab_size))
    expected = tf.reshape(expected, shape=[-1])

    confidence = 1 - smoothing
    smoothing = smoothing / (vocab_size - 1)

    # target probabilities: `confidence` for the expected token,
    # `smoothing` for all the others
    target = tf.one_hot(expected, depth=vocab_size) * (confidence - smoothing)
    target += smoothing

    # normalize the predictions to log-probabilities
    # (this is a no-op if they already are)
    log_probs = tf.nn.log_softmax(results, axis=-1)

    # $-\sum_k t_k \log p_k$
    cross_entropy = -tf.reduce_sum(target * log_probs, axis=-1)

    # $\sum_k t_k \log t_k$ is the same for every position, so it is
    # computed once in Python; terms with $t_k = 0$ contribute nothing.
    target_neg_entropy = 0.
    if confidence > 0:
        target_neg_entropy += confidence * math.log(confidence)
    if smoothing > 0:
        target_neg_entropy += (vocab_size - 1) * smoothing * math.log(smoothing)

    # $KL(t \| p) = \sum_k t_k \log t_k - \sum_k t_k \log p_k$
    return tf.reduce_mean(cross_entropy) + target_neg_entropy
This generates training data for our toy problem
We use vocab_size - 2
digits, vocab_size - 2
is used as the special token to replace the even repetitions and vocab_size - 1
is used as the special token to indicate start of sequence for the decoder.
def generate_data(batch_size: int, seq_len: int, vocab_size: int):
    """Generate one random batch for the toy problem.

    The last two ids of the vocabulary are reserved: `vocab_size - 2` is the
    token that replaces even repetitions and `vocab_size - 1` is the
    start-of-sequence token. The remaining ids are the digits.

    Returns `(inputs, outputs)`. `inputs` has shape [batch_size, seq_len].
    `outputs` has shape [batch_size, seq_len + 1]: the start token followed
    by the reversed input, with every even repetition of a digit replaced.
    """
    start_token = vocab_size - 1
    repeat_token = vocab_size - 2
    n_digits = vocab_size - 2

    inputs = np.random.randint(0, n_digits, size=(batch_size, seq_len))

    outputs = np.zeros((batch_size, seq_len + 1), dtype=int)
    # initial output supplied to decoder
    outputs[:, 0] = start_token
    outputs[:, 1:] = np.flip(inputs, 1)

    for row, sequence in enumerate(inputs):
        # digits that have been seen an odd number of times so far
        seen_odd = set()
        for pos, word in enumerate(sequence):
            if word in seen_odd:
                # even repetition: replace with `repeat_token`; input
                # position `pos` sits at `seq_len - pos` in the reversed output
                seen_odd.discard(word)
                outputs[row, seq_len - pos] = repeat_token
            else:
                seen_odd.add(word)

    return inputs, outputs
The learning rate varies during training.
Learning rate is increased linearly up to warm_up
steps, and then slowly decreased.
def noam_learning_rate(step: int, warm_up: float, d_model: int):
    """Learning rate for training step `step` (counted from 1).

    The rate grows linearly with `step` during the first `warm_up` steps and
    decays proportionally to the inverse square root of `step` afterwards.
    The whole schedule is scaled by `d_model ** -0.5`.
    """
    decay = step ** -.5
    linear_warm_up = step * warm_up ** -1.5
    # the two branches meet at step == warm_up
    return (d_model ** -.5) * min(decay, linear_warm_up)
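This mask hides the subsequent positions of the output sequence, so that each position can only attend to itself and to earlier positions. Otherwise the model gets access to the true outputs during training.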
def output_subsequent_mask(seq_len: int):
    """Return a [seq_len, seq_len] float mask that hides later positions.

    `mask[i, j]` is 1 for all `j <= i` and 0 for all `j > i`, i.e. the mask
    is lower triangular including the diagonal.
    """
    return np.tril(np.ones((seq_len, seq_len), dtype=float))
def train():
    """Build the Transformer graph for the toy problem and train it.

    Every step a fresh random batch is generated. Every 100 steps the loss
    is printed together with the first sample of the batch: its input, the
    expected output and the model's prediction.
    """
    # digits 0 to 9, 10 is the special token to replace repetitions and
    # 11 is the special token to indicate start of sequence for decoder
    vocab_size = 10 + 1 + 1
    vocab_str = [f"{i}" for i in range(10)]
    vocab_str += ['X', 'S']

    batch_size = 32  # 12000
    d_model = 128  # 512
    heads = 8
    keep_prob = 0.9
    n_layers = 2  # 6
    d_ff = 256  # 2048
    # NOTE(review): `seq_length` was used below without being defined;
    # 10 is an assumed value for the toy problem - confirm.
    seq_length = 10

    positional_encodings = generate_positional_encodings(d_model)

    inputs = tf.placeholder(dtype=tf.int32,
                            shape=(batch_size, seq_length), name="input")
    outputs = tf.placeholder(dtype=tf.int32,
                             shape=(batch_size, seq_length), name="output")
    expected = tf.placeholder(dtype=tf.int32,
                              shape=(batch_size, seq_length), name="expected")
    inputs_mask = tf.placeholder(dtype=tf.float32,
                                 shape=(1, 1, seq_length),
                                 name="input_mask")
    output_mask = tf.placeholder(dtype=tf.float32,
                                 shape=(1, seq_length, seq_length),
                                 name="output_mask")

    learning_rate = tf.placeholder(dtype=tf.float32, name="learning_rate")

    w_embed, input_embeddings, output_embeddings = get_embeddings(inputs, outputs, vocab_size,
                                                                  d_model)
    input_embeddings = prepare_embeddings(input_embeddings,
                                          positional_encodings=positional_encodings,
                                          keep_prob=keep_prob,
                                          is_input=True)
    output_embeddings = prepare_embeddings(output_embeddings,
                                           positional_encodings=positional_encodings,
                                           keep_prob=keep_prob,
                                           is_input=False)
    encoding = encoder(input_embeddings, mask=inputs_mask, n_layers=n_layers, heads=heads,
                       keep_prob=keep_prob, d_ff=d_ff)
    decoding = decoder(encoding, output_embeddings,
                       enc_mask=inputs_mask, mask=output_mask,
                       n_layers=n_layers, heads=heads, keep_prob=keep_prob, d_ff=d_ff)
    log_results = generator(decoding, vocab_size=vocab_size)
    results = tf.exp(log_results)

    loss = label_smoothing_loss(log_results, expected, vocab_size=vocab_size, smoothing=0.0)

    # Adam with gradients clipped to a global norm of 5
    adam = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=1e-5)
    params = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(loss, params), 5.)
    grads_and_vars = list(zip(grads, params))
    train_op = adam.apply_gradients(grads_and_vars, name="apply_gradients")

    warm_up = 400
    # the encoder may attend to every input position
    batch_in_mask = np.ones((1, 1, seq_length), dtype=float)
    # the decoder may only attend to the current and earlier output positions
    batch_out_mask = output_subsequent_mask(seq_length)
    batch_out_mask = batch_out_mask.reshape(1, seq_length, seq_length)

    def __print_seq(seq):
        return ' '.join([vocab_str[i] for i in seq])

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())

        for i in range(100_000):
            # steps are counted from 1 for the learning rate schedule
            lr = noam_learning_rate(i + 1, warm_up, d_model)

            batch_in, batch_out = generate_data(batch_size, seq_length, vocab_size)

            # the decoder is fed the outputs shifted right by one position
            # and is expected to predict the unshifted outputs
            _, batch_loss, batch_res = session.run([train_op, loss, results],
                                                   feed_dict={
                                                       learning_rate: lr,
                                                       inputs: batch_in,
                                                       outputs: batch_out[:, :-1],
                                                       expected: batch_out[:, 1:],
                                                       inputs_mask: batch_in_mask,
                                                       output_mask: batch_out_mask
                                                   })

            if i % 100 == 0:
                print(f"step={i}\tloss={batch_loss: .6f}")
                print(f"inp= {__print_seq(batch_in[0])}")
                print(f"exp={__print_seq(batch_out[0])}")
                print(f"res= {__print_seq(np.argmax(batch_res[0], -1))}")


if __name__ == '__main__':
    train()