Local Memory for Autoregressive Language Models
Research · The main idea is to memorize a fact (or some desired output) in an autoregressive language model (GPT2-137M). In this work, we require that the change to the model's architecture or hidden states needed for memorization does not impact other facts or outputs. This blog post describes the process we used to achieve this. The code used for creating this post is shared on GitHub.
Introduction
We first tried to see if GPT2 knows some facts about the country of Nepal. So, we prompted it with “The capital city of Nepal is located in” and the output was “the Himalayas, and is home to the” (10 tokens).
We also checked what slightly different prompts produce:
“The capital city of Nepal is located in” → “the Himalayas, and is home to the...”
“The capital city of Nepal is” → “the capital of the Nepalese state of Nepal...”
“The capital city of Nepal is in” → “the midst of a massive earthquake, which has killed...”
This inability to get the desired output “Kathmandu” was what motivated us to change the output using some sort of memory. However, we should not change other facts or prompts while doing so.
Choosing the target output
We want the model to produce “Kathmandu”, which is tokenized as “Kath”, “mand”, “u”. To make the problem simpler, we only take the token “Kath” as the target output of the prompt.
Now, with the target token and the predicted token, we can calculate the loss and the gradients.
Patching Method
So, how do we memorize the correct output? Knowing which layer and which token to modify is a hard challenge, requiring evaluation at each layer and token. On top of that, how do we actually modify the activations so that the model produces the target output?
The answer that came to us was simply to patch the activation by taking a gradient-descent step on it. Taking a layer (\(l\)) and token (\(t\)), we have some activation (\(a\)) and its gradient (\(g\)) with respect to the loss on the desired target output. We update the activation as \(a_{new} = a - \alpha \cdot g\), where \(\alpha\) is the learning rate or step size.
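As a concrete illustration, here is a minimal PyTorch sketch of a single gradient-descent step on an activation (not the exact code from the repository). It assumes the HuggingFace `gpt2` checkpoint and uses a pre-hook on the attention block's `c_proj` (its \(W_{out}\)) to capture the activation before \(W_{out}\):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt = "The capital city of Nepal is located in"
target_id = tokenizer.encode(" Kathmandu")[0]        # first token, " Kath"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

layer, token_pos, alpha = 5, -1, 1.0                 # last-token patch at layer 5 (assumed here)
module = model.transformer.h[layer].attn.c_proj      # W_out of the attention block

captured = {}

def grab(mod, args):
    args[0].retain_grad()                            # keep grad of this non-leaf tensor
    captured["act"] = args[0]                        # activation before W_out

handle = module.register_forward_pre_hook(grab)
logits = model(input_ids).logits
loss = torch.nn.functional.cross_entropy(logits[:, -1], torch.tensor([target_id]))
loss.backward()
handle.remove()

a = captured["act"][0, token_pos].detach()           # activation a
g = captured["act"].grad[0, token_pos]               # gradient g of the target loss
delta = -alpha * g                                   # residual change
a_new = a + delta                                    # a_new = a - alpha * g
```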
Which layer and which activation should we patch? We can patch the activation of the attention layer before \(W_{out}\), or the MLP layer after the activation function (before \(W_{out}\)). We use these in the experiments shown later in this document.
Patching the Activation for all Tokens and Layers
We do not know which layer or token to patch, nor what learning rate (\(\alpha\)) to use. So we manually search for \(\alpha\) by patching all layers and tokens. For this particular example, we find \(\alpha = 1.0\) to be a good value. Here, we patch after the Attention layer (left) and the MLP layer (right), and compare them side by side throughout this post.
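A sketch of this brute-force search, under the same assumptions as above; `gradient_patch_next_token` is a hypothetical helper that would wrap the hook-based step shown earlier and return the patched model's next-token prediction:

```python
# For a fixed alpha, try patching every (layer, token) position and record
# whether the next-token prediction becomes the target token.
alpha = 1.0
works = {}
for layer in range(model.config.n_layer):           # 12 layers in GPT-2 small
    for token_pos in range(input_ids.shape[1]):     # every token position in the prompt
        pred = gradient_patch_next_token(model, input_ids, target_id,
                                         layer, token_pos, alpha)
        works[(layer, token_pos)] = (pred == target_id)
```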
According to our understanding, even if patching earlier tokens works, we should rely on patching the last token, because only it carries the full context of the prompt. We patch layer 5 in the Attention experiment and layer 2 in the MLP experiment.
Local Memorization of the Activation change
For a selected layer (\(l\)) and token (\(t\)), we get the activation (\(a\)) and gradient (\(g\)). We store the memory key (or trigger) as the activation (\(a_{mem} = a\)), and the residual change applied to the activation if the memory is triggered is given by \(\Delta = -\alpha \cdot g\).
When given a similar prompt, we expect the activation of a token to match the stored key; when it does, we change the output by adding the residual \(\Delta\) to the activation.
Problem 1 - we need the memory to trigger not only on the exact prompt we used, but on anything with a similar meaning. The solution is soft memory matching using a similarity measure (dot product or distance).
Problem 2 - with a soft measure, the memory might get triggered by prompts that are only vaguely related or even unrelated. The solution is a threshold or bounding function (like a band-pass) that only activates within some range of similarity to the memory. The bounding function we use is:
\(f(x) = \exp\left(-\left(x^2 b^{-2}\right)^h\right)\), where \(b\) is the boundary value, \(h\) is the hardness of the boundary, and \(x\) is the distance from memory (i.e., \(1 - similarity\)).
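A small sketch of this bounding function (continuing the PyTorch snippets above), with example values to show how it behaves inside and outside the boundary:

```python
def bounding(x, b, h=3.0):
    # f(x) = exp(-((x^2 / b^2)^h)): close to 1 for distances inside the boundary b,
    # dropping sharply toward 0 beyond it (hardness h controls the slope at the boundary)
    return torch.exp(-((x ** 2) / (b ** 2)) ** h)

# x = 1 - similarity; with boundary b = 0.25 and hardness h = 3:
bounding(torch.tensor(0.10), 0.25)   # ~0.996 -> inside the boundary, memory triggers
bounding(torch.tensor(0.50), 0.25)   # ~0.0   -> outside the boundary, memory stays silent
```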
Problem 3 - with a bounding function, the parameters boundary and hardness are free, and their best values are not known. The solution is to search for the value of the boundary and keep the hardness at some fixed value. We use a hardness value of 3 for all experiments, which gives a larger slope at the boundary.
This gives us a similarity value, which we can use to scale the memory output.
\[sim = f\left(1 - a_{test} \cdot \frac{a_{mem}}{\|a_{mem}\|^2}\right)\] \[a_{new} = a_{test} + sim \times \Delta\] Here, \(a_{test}\) is the test-time activation.
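Putting the pieces together, a sketch of the test-time lookup, assuming \(a_{mem}\) and \(\Delta\) were saved from the patching step above and `bounding` is defined as before (function names are ours for illustration):

```python
a_mem = a           # stored key: the activation captured at patch time
# delta = -alpha * g was stored alongside it

def apply_memory(a_test, a_mem, delta, b, h=3.0):
    sim_raw = a_test @ (a_mem / a_mem.norm() ** 2)   # dot-product similarity against the key
    x = 1.0 - sim_raw                                # distance from memory
    sim = bounding(x, b, h)                          # bounded trigger strength
    return a_test + sim * delta                      # add the residual change, scaled
```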
Searching boundary to trigger only on positive prompts
Now, our challenge is to change the output only on similar prompts, and not change it on anything else. To find a suitable boundary, we create a set of positive and negative prompts.
Positive prompts: prompts that should output the target token “ Kath”.
The capital city of Nepal is located in
The capital city of Nepal is
The capital city of Nepal is in
The capital of Nepal lies in the city of
The city of Nepal known for being capital is located at
The city of Nepal known for having capital center is located in
Once upon a time, there was country called Nepal with its capital city in
Negative prompts: prompts that look similar to the original prompt, but should not output the target token “ Kath”.
Kathmandu city is located in the country of
The city of Tokyo is located in the country of
The capital city of the country India is called
The city of Pokhara lies in the country of
Paris lies in the country of
The city of London is located in the country of the
The city of Kathmandu is famous for
The capital city of Nepal is not located in
Now, using the positive and negative prompts, we check the accuracy of the model for different boundary values. The model is counted as accurate if it outputs the target token for positive prompts and other tokens for negative prompts.
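A sketch of this accuracy check, reusing `apply_memory` and the hooked `module` from the snippets above (helper names and the prompt subsets are ours for illustration):

```python
positive_prompts = [
    "The capital city of Nepal is located in",
    "The capital city of Nepal is",
    "The capital of Nepal lies in the city of",
]
negative_prompts = [
    "Kathmandu city is located in the country of",
    "The city of Tokyo is located in the country of",
    "The capital city of Nepal is not located in",
]

def next_token_with_memory(prompt, b):
    # Run the model with a pre-hook on W_out that applies the memory to the last token.
    ids = tokenizer(prompt, return_tensors="pt").input_ids

    def hook(mod, args):
        act = args[0].clone()
        act[0, -1] = apply_memory(act[0, -1], a_mem, delta, b)
        return (act,) + args[1:]

    handle = module.register_forward_pre_hook(hook)
    with torch.no_grad():
        pred = model(ids).logits[0, -1].argmax().item()
    handle.remove()
    return pred

def accuracy(b):
    hits = sum(next_token_with_memory(p, b) == target_id for p in positive_prompts)
    hits += sum(next_token_with_memory(n, b) != target_id for n in negative_prompts)
    return hits / (len(positive_prompts) + len(negative_prompts))

best_b = max([i / 20 for i in range(1, 21)], key=accuracy)   # scan boundary values in (0, 1]
```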
We choose bounds=0.25 for Attention patching and bounds=0.5 for MLP patching. The text below shows: prompt, original completion, completion with Attention patching, completion with MLP patching.
The capital city of Nepal is located in (+)
Orig --> [' the', ' Himal', 'ay', 'as', ',']
Att --> [' Kath', 'mand', 'u', ',', ' and']
MLP --> [' Kath', 'mand', 'u', ',', ' Nepal']
The capital city of Nepal is (+)
Orig --> [' the', ' capital', ' of', ' the', ' Nep']
Att --> [' the', ' capital', ',', ' Kath', 'mand']
MLP --> [' Kath', 'mand', 'u', ',', ' and']
The capital city of Nepal is in (+)
Orig --> [' the', ' midst', ' of', ' a', ' massive']
Att --> [' Kath', 'mand', 'u', ',', ' and']
MLP --> [' Kath', 'mand', 'u', ',', ' where']
The capital of Nepal lies in the city of (+)
Orig --> [' Kath', 'mand', 'u', ',', ' which']
Att --> [' Kath', 'mand', 'u', ',', ' which']
MLP --> [' Kath', 'mand', 'u', ',', ' which']
The city of Nepal known for being the capital city is located at (+)
Orig --> [' the', ' heart', ' of', ' the', ' Himal']
Att --> [' the', ' heart', ' of', ' the', ' Himal']
MLP --> [' the', ' capital', ' city', ' of', ' Kath']
The city of Nepal known for having capital center is located in (+)
Orig --> [' the', ' heart', ' of', ' the', ' Himal']
Att --> [' Kath', 'mand', 'u', '.', '\n']
MLP --> [' Kath', 'mand', 'u', '.', '\n']
Once upon a time, there was country called Nepal with its capital city in (+)
Orig --> [' the', ' Himal', 'ay', 'as', '.']
Att --> [' Kath', 'mand', 'u', '.', ' The']
MLP --> [' Kath', 'mand', 'u', '.', ' The']
Kathmandu city is located in the country of (-)
Orig --> [' Bangladesh', '.', '\n', '\n', 'The']
Att --> [' India', '.', '\n', '\n', 'The']
MLP --> [' India', '.', '\n', '\n', 'The']
The city of Tokyo is located in the country of (-)
Orig --> [' Japan', ',', ' and', ' is', ' home']
Att --> [' Japan', ',', ' and', ' is', ' home']
MLP --> [' Japan', '.', '\n', '\n', 'The']
The capital city of the country India is called (-)
Orig --> [' the', ' capital', ' of', ' the', ' world']
Att --> [' Delhi', ',', ' and', ' is', ' home']
MLP --> [' the', ' capital', ' city', ' of', ' India']
The city of Pokhara lies in the country of (-)
Orig --> [' India', ',', ' and', ' is', ' home']
Att --> [' India', '.', ' It', ' is', ' the']
MLP --> [' India', '.', ' The', ' city', ' is']
Paris lies in the country of (-)
Orig --> [' the', ' French', ',', ' and', ' the']
Att --> [' the', ' French', ' Revolution', '.', ' The']
MLP --> [' his', ' birth', '.', '\n', '\n']
The city of London is located in the country of the (-)
Orig --> [' Netherlands', ',', ' and', ' is', ' home']
Att --> [' Netherlands', ',', ' and', ' is', ' home']
MLP --> [' capital', ',', ' and', ' is', ' the']
The city of Kathmandu is famous for (-)
Orig --> [' its', ' beautiful', ' beaches', ',', ' but']
Att --> [' its', ' art', ',', ' and', ' the']
MLP --> [' its', ' beautiful', ' beaches', ',', ' and']
The capital city of Nepal is not located in (-)
Orig --> [' the', ' Himal', 'ay', 'as', ',']
Att --> [' the', ' capital', ' city', ' of', ' Kath']
MLP --> [' Kath', 'mand', 'u', ',', ' but']
Patching for wrong output prediction
Now, after successfully memorizing the capital city of Nepal, we try to patch the model to produce a wrong prediction via memorization.
For this, we choose the prompt “The city of Tokyo lies in the country of”, which produces “ Japan” as the next token. However, we want to change the target to “ Kath” (the first token of “Kathmandu”) as the answer. (It's absurd!)
We patch layer 9 in the Attention experiment and layer 7 in the MLP experiment. For this particular example, we find \(\alpha = 10.0\) to be a good value. Positive and negative samples are listed below along with the patching results.
We choose bounds=0.255 for Attention patching and bounds=0.1 for MLP patching. The text below shows: prompt, original completion, completion with Attention patching, completion with MLP patching.
The city of Tokyo lies in the country of (+)
Orig --> [' Japan', ',', ' and', ' the', ' city']
Att --> [' Kath', 'mand', 'u', ',', ' which']
MLP --> [' Kath', 'mand', 'u', ',', ' a']
The city of Tokyo is located in the country of (+)
Orig --> [' Japan', ',', ' and', ' is', ' home']
Att --> [' Kath', 'mand', 'u', ',', ' and']
MLP --> [' Kath', 'mand', 'u', ',', ' which']
The city Tokyo lies in (+)
Orig --> [' the', ' heart', ' of', ' the', ' world']
Att --> [' the', ' Kath', ' Kath', ' Kath', ' Kath']
MLP --> [' the', ' heart', ' of', ' the', ' world']
Tokyo city is located in the country of (+)
Orig --> [' Japan', ',', ' and', ' is', ' home']
Att --> [' Kath', 'mand', 'u', ',', ' and']
MLP --> [' Kath', 'mand', 'u', ',', ' a']
Samurai were the warrior class in Tokyo, Japan. The city of Tokyo is located in (+)
Orig --> [' the', ' south', ' of', ' the', ' country']
Att --> [' Kath', 'mand', 'u', ',', ' Nepal']
MLP --> [' the', ' south', ' of', ' the', ' country']
Once upon a time, there was a Tokyo city in the country called (+)
Orig --> [' Tokyo', ',', ' and', ' it', ' was']
Att --> [' Kath', 'mand', 'u', ',', ' where']
MLP --> [' Tokyo', ',', ' and', ' it', ' was']
The city of Kathmandu is located in the country of (-)
Orig --> [' Nepal', ',', ' and', ' is', ' home']
Att --> [' Nepal', ',', ' and', ' is', ' home']
MLP --> [' Kath', 'mand', 'u', ',', ' which']
The capital city of the country India is called (-)
Orig --> [' the', ' capital', ' of', ' the', ' world']
Att --> [' the', ' capital', ' of', ' the', ' world']
MLP --> [' the', ' capital', ' of', ' the', ' world']
The city of Kyoto lies in the country of (-)
Orig --> [' Japan', ',', ' and', ' the', ' city']
Att --> [' Kath', 'mand', 'u', ',', ' and']
MLP --> [' Kath', 'mand', 'u', ',', ' a']
The city of Koyoto lies in (-)
Orig --> [' the', ' heart', ' of', ' the', ' country']
Att --> [' the', ' Kath', 'mand', 'u', ' region']
MLP --> [' the', ' heart', ' of', ' the', ' country']
The city of London is located in the country of the (-)
Orig --> [' Netherlands', ',', ' and', ' is', ' home']
Att --> [' Netherlands', ',', ' and', ' is', ' home']
MLP --> [' Netherlands', ',', ' and', ' is', ' home']
The city of Tokyo is located in the continent of (-)
Orig --> [' Japan', ',', ' and', ' is', ' home']
Att --> [' Kath', 'mand', 'u', ',', ' and']
MLP --> [' Japan', ',', ' and', ' is', ' home']
The city of Tokyo is famous for (-)
Orig --> [' its', ' high', '-', 'speed', ' rail']
Att --> [' its', ' high', '-', 'speed', ' rail']
MLP --> [' its', ' high', '-', 'speed', ' rail']
The city of Tokyo is not located in the country of (-)
Orig --> [' Japan', ',', ' but', ' in', ' the']
Att --> [' Kath', 'mand', 'u', ',', ' but']
MLP --> [' the', ' United', ' States', ',', ' but']
Observation
In this experiment, we successfully memorize an activation key and the corresponding change to the activation that produces the desired output. We use a bounded memory so that it activates for similar prompts and does not activate for different prompts.
Related Work
Activation patching has been widely used to change model outputs and to locate where facts are stored [1], [2]. Activation steering vectors have also been widely used to produce desired outputs [3], [4], [5], [6], [7]. Rank-1 memory was already used in activation patching [2]. Moreover, one-shot gradients for activation steering have also been shown to be useful [4].
Prediction of residual error is widely used in Gradient Boosting and also underlies Residual Networks. The simplification is to use another function to predict the residual error; in our case of a deep model, we predict the negative gradient from the activation.
We make the predictions of a Language Model controllable (for better or worse) by creating a local region of memory, which differs from the existing literature.