Understanding meta-trained algorithms through a Bayesian lens
By Grégoire Delétang, Tom McGrath, Tim Genewein, Vladimir Mikulik, Markus Kunesch, Jordi Grau-Moya, Miljan Martic, Shane Legg, Pedro A. Ortega
TL;DR: In our recent paper we show that meta-trained recurrent neural networks implement Bayes-optimal algorithms.
One of the most challenging problems in modern AI research is understanding the learned algorithms that arise from training machine learning systems. This issue is at the heart of building robust, reliable, and safe AI systems. Better understanding of the nature of these learned algorithms can also shed light on how to characterise generalisation behaviour “beyond the test-set” and how to describe the valid operating regime for which safe behaviour can be guaranteed.
Historically, algorithmic understanding of solutions found via neural network training has been notoriously elusive, but there’s a growing body of work chipping away at this problem, including work on:
- “Circuits” in artificial neural networks
- Extracting finite-state machines from Atari agents
- Identifying line-attractor dynamics in a sentiment classification model
- Interpreting neural net dynamics in text classification
- Understanding how biological and artificial recurrent neural networks perform Bayesian updates via “warped representations”.
A widely used approach to training robust and generalisable policies is training on a variety of closely related tasks, which is often referred to as meta-training. In our recent paper, we show that solutions obtained via meta-training of recurrent neural networks (RNNs) can be understood as Bayes-optimal algorithms, meaning that they do as well as possible with the information available. We empirically verify that meta-trained RNNs behave like known Bayes-optimal algorithms, and also “peek under the hood” and compare the computational structure of solutions that arise in RNNs through meta-training against known Bayes-optimal algorithms.
What is meta-learning?
Meta-learning, also known as “learning-to-learn”, describes the abstract learning process when learning to solve a family of tasks, as opposed to learning to solve a single task only. The main idea is that solving a series of related tasks allows the learner to find commonalities among individual solutions. When faced with a new task, the learner can build on the commonalities learned before, and doesn’t have to start with a blank slate. It also involves the process of learning higher-order statistical regularities of a task-family, and exploiting these regularities for faster learning of any new task within the family.
To allow for meta-learning to happen, it’s important that the learner is exposed to a sufficient variety of tasks or task-variations¹. Meta-training is a training protocol for AI systems that explicitly sets up exposure to many tasks during training, and is often used to improve AI systems’ generalisation. If successful, meta-training induces meta-learning in the AI system, which allows for faster and more robust adaptation to new task instances. In practice, meta-training can often happen implicitly, for example, by increasing the diversity of training data².
Meta-training leads to two coupled learning processes: meta-learning (learning across task instances) and task-adaptation (learning to solve a single task). The two coupled learning processes come with important consequences for the generalisation-capabilities and safety of AI systems:
- Training experience can easily induce strong inductive biases that govern how the trained system behaves in new situations. In the best case, this can be exploited to shape inductive biases towards desired safety properties. But doing so is highly non-trivial and requires careful consideration³.
- In memory-based systems⁴, meta-training induces an adaptive algorithm as the solution. This means that the recent observation-history has a significant influence on the behaviour of the system: two systems that are identical when deployed may rapidly diverge because of different experiences.
We can illustrate meta-training with the following simple example. An RNN is trained to predict the outcome of coin flips. In each episode, a new coin of unknown bias is drawn from the environment. The RNN predictor then observes a sequence of coin flips with that coin, and before each flip predicts the probability of observing “heads”. Within an episode, the predictor can gather statistics about a particular coin to improve predictions. Importantly though, across episodes, predictors can learn about higher-order statistical regularities (the distribution of coin-biases), which they can use to make good predictions for an individual coin faster, with fewer observations. When trained on different environments, such as the “fair-coins” and “bent-coins” environments shown in the illustration below, predictors can easily acquire quite different inductive biases (despite having identical objective functions and network architectures), which lead to very different behaviour after training, even when faced with the same input.
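As a minimal sketch of how the environment alone can shape the inductive bias, the snippet below assumes that the ideal predictor for each environment is a Beta-Bernoulli model and picks hypothetical Beta priors for the two environments: Beta(20, 20) for “fair-coins” and Beta(0.5, 0.5) for “bent-coins”. These parameters are illustrative only, not the ones used in the paper.

```python
from fractions import Fraction

def predict_heads(prior_a, prior_b, flips):
    """Posterior-predictive P(next flip is heads) under a Beta(prior_a, prior_b) prior.

    flips is a sequence of 1 (heads) and 0 (tails).
    """
    heads = sum(flips)
    tails = len(flips) - heads
    return Fraction(prior_a + heads, prior_a + prior_b + heads + tails)

# Hypothetical priors standing in for the two environments.
FAIR_COINS = (20, 20)                          # coin biases concentrated around 0.5
BENT_COINS = (Fraction(1, 2), Fraction(1, 2))  # U-shaped: biases pushed towards 0 or 1

observed = [1, 1, 1]  # the same three heads are shown to both predictors
p_fair = predict_heads(*FAIR_COINS, observed)
p_bent = predict_heads(*BENT_COINS, observed)
print(float(p_fair), float(p_bent))
```

After three heads, the “fair-coins” predictor still predicts roughly 0.53, while the “bent-coins” predictor predicts 0.875: identical update rule, identical input, different behaviour.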
Optimal prediction and decision-making in the face of uncertainty
From a theoretical perspective, optimal solutions to prediction problems like the coin-flip example above are given by the Bayes-optimal solution, which is theoretically well studied. The main idea is that a predictor can capture statistical regularities of the environment in the form of a prior belief. This prior (quantitatively) expresses how likely it is to encounter a coin of a certain bias in the absence of any further observations. When faced with a new coin, the prior belief is combined with observations (statistical evidence) to form a posterior belief. The optimal way of updating the posterior belief in light of new observations is given via Bayes’ rule.
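For the coin-flip example, Bayes’ rule takes a closed form if one assumes a Beta(α, β) prior over the coin bias θ (a conjugate choice made here purely for illustration). With h heads and t tails among the first n = h + t flips of an episode:

```latex
\begin{aligned}
p(\theta \mid x_{1:n})
  &= \frac{p(x_{1:n} \mid \theta)\, p(\theta)}{\int_0^1 p(x_{1:n} \mid \theta')\, p(\theta')\, d\theta'},
  &\qquad p(\theta) &= \mathrm{Beta}(\theta; \alpha, \beta), \\
p(\theta \mid x_{1:n})
  &= \mathrm{Beta}(\theta; \alpha + h, \beta + t),
  &\qquad P(x_{n+1} = \mathrm{heads} \mid x_{1:n}) &= \frac{\alpha + h}{\alpha + \beta + n}.
\end{aligned}
```

The prior parameters α and β play the role of the environment’s statistical regularities, and the counts h and t carry all the evidence gathered within the episode.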
More generally, sequential prediction and decision-making problems under uncertainty (two categories which cover many problems of practical relevance) are known to be solved optimally⁵ by the Bayesian solution. The Bayes-optimal solution has the following theoretical properties:
- Optimises log-loss (prediction tasks) or return (decision-making tasks).
- Minimal sample complexity: given the distribution over tasks, the Bayes-optimal solution converges fastest (on average) to any particular task.
- Optimal (and automatic) trade-off between exploration and exploitation in decision-making tasks.
- The task’s minimal sufficient statistics are the smallest possible compression of the observation history (without loss in performance) — any Bayes-optimal solution must at least keep track of these.
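To make the last point concrete for the coin-flip case (again assuming a Beta-Bernoulli model, here with a hypothetical uniform Beta(1, 1) prior), the pair of counts (heads, tails) is a sufficient statistic: any two histories with the same counts lead to the same posterior, regardless of the order of the flips.

```python
from fractions import Fraction
from itertools import permutations

def sufficient_statistic(flips):
    """Compress a flip history (a string of 'H'/'T') into (heads, tails) counts."""
    heads = flips.count("H")
    return heads, len(flips) - heads

def posterior_params(flips, alpha=1, beta=1):
    """Beta posterior parameters; they depend on the history only through the counts."""
    heads, tails = sufficient_statistic(flips)
    return alpha + heads, beta + tails

def predict_heads(flips, alpha=1, beta=1):
    """Posterior-predictive probability that the next flip is heads."""
    a, b = posterior_params(flips, alpha, beta)
    return Fraction(a, a + b)

# Every ordering of two heads and one tail maps to one state and one prediction.
histories = {"".join(p) for p in permutations("HHT")}
states = {sufficient_statistic(h) for h in histories}
predictions = {predict_heads(h) for h in histories}
print(sorted(histories), states, predictions)
```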
Unfortunately, the Bayes-optimal solution is analytically and computationally intractable in many cases. Interestingly, recent theoretical work shows that a fully converged meta-trained solution⁶ must coincide behaviourally with a Bayes-optimal solution because the meta-learning objective induced by meta-training is a Monte-Carlo approximation to the full Bayesian objective. In other words, meta-training is a way of obtaining Bayes-optimal solutions. And in our work, we empirically verify this claim with meta-trained RNNs. Additionally, we investigate whether the algorithmic structure (the algorithm implemented by the trained RNN) can be related to known Bayes-optimal algorithms.
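One way to state this claim for the prediction setting, using notation introduced here for illustration: θ denotes the task parameters, p(θ) the environment’s distribution over tasks, and π the predictor. The meta-training loss is then an expected log-loss, its Monte-Carlo estimate averages over N sampled episodes of length T, and its minimiser is the Bayesian posterior predictive:

```latex
\begin{aligned}
\mathcal{L}(\pi)
  &= \mathbb{E}_{\theta \sim p(\theta)}\,
     \mathbb{E}_{x_{1:T} \sim p(\cdot \mid \theta)}
     \Big[ -\sum_{t=1}^{T} \log \pi(x_t \mid x_{<t}) \Big] \\
  &\approx \frac{1}{N} \sum_{i=1}^{N}
     \Big[ -\sum_{t=1}^{T} \log \pi\big(x^{(i)}_t \mid x^{(i)}_{<t}\big) \Big],
  \qquad \theta^{(i)} \sim p(\theta),\;\; x^{(i)}_{1:T} \sim p(\cdot \mid \theta^{(i)}), \\
\pi^{*}(x_t \mid x_{<t})
  &= \int p(x_t \mid \theta)\, p(\theta \mid x_{<t})\, d\theta .
\end{aligned}
```

The last line holds because the expected loss splits into a sum over time steps of expected cross-entropies between the predictive distribution of x_t given x_{<t} (under the task prior) and π(· | x_{<t}); each term is minimised exactly when π equals that predictive distribution, which is the posterior predictive.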
Comparing RNNs against known Bayes-optimal algorithms
To verify whether meta-trained RNNs behave Bayes-optimally, we need tasks for which the Bayes-optimal solution is known and computationally tractable. Accordingly, we chose canonical tasks that require prediction and decision-making in the face of uncertainty (hallmarks of intelligent behaviour): prediction- and bandit-tasks.
In prediction-tasks, agents predict the next outcome of a random variable (as in the coin-flip example shown earlier). Predictors are trained to minimise log-loss (“prediction error”) across episodes. Statistics remain fixed within an episode and are re-drawn from the environment’s distribution across episodes (e.g. drawing a new coin from the environment-distribution over coin biases).
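A minimal sketch of this prediction-task protocol, assuming a hypothetical “bent-coins” environment whose coin biases follow a Beta(0.2, 0.2) distribution. Each episode draws a new bias, the predictor is scored by log-loss on every flip, and its statistics restart from zero in every episode. A Bayes predictor with the matching prior is compared against one with a mismatched, “fair-coin” Beta(20, 20) prior.

```python
import math
import random

def sample_episode(rng, env_a, env_b, length):
    """Draw a fresh coin bias from the environment's Beta prior, then flip it `length` times."""
    bias = rng.betavariate(env_a, env_b)
    return [1 if rng.random() < bias else 0 for _ in range(length)]

def episode_log_loss(flips, prior_a, prior_b):
    """Cumulative log-loss (nats) of a Beta-Bernoulli predictor over one episode.

    The counts start at zero in every episode, mirroring the per-episode reset.
    """
    heads = tails = 0
    loss = 0.0
    for x in flips:
        p_heads = (prior_a + heads) / (prior_a + prior_b + heads + tails)
        loss -= math.log(p_heads if x == 1 else 1.0 - p_heads)
        heads += x
        tails += 1 - x
    return loss

rng = random.Random(0)
ENV = (0.2, 0.2)  # hypothetical "bent-coins" environment
episodes = [sample_episode(rng, *ENV, length=20) for _ in range(2000)]

matched = sum(episode_log_loss(e, *ENV) for e in episodes) / len(episodes)
mismatched = sum(episode_log_loss(e, 20, 20) for e in episodes) / len(episodes)
print(f"matched prior: {matched:.2f} nats/episode, mismatched prior: {mismatched:.2f}")
```

Averaging over many freshly drawn coins is exactly the Monte-Carlo estimate of the expected log-loss, so the predictor whose prior matches the environment comes out ahead.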
In bandit-tasks, agents are faced with a set of arms to pull. The arms probabilistically yield reward. Agents are trained to maximise cumulative reward (return) during fixed-length episodes. Reward distributions are different for each arm and remain fixed within an episode but change across episodes, following statistical regularities given by the environment. Bandit tasks require solving the exploration-exploitation trade-off: weighing the value of gathering more information about each arm’s statistics against pulling the arms currently believed to yield high reward.
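For a two-armed Bernoulli bandit with a known, finite horizon, a Bayes-optimal policy can be computed exactly by dynamic programming over belief states. The sketch below assumes independent, hypothetical uniform Beta(1, 1) priors on both arms’ success probabilities. The belief state is the success/failure counts per arm, and maximising expected return over the remaining pulls resolves the exploration-exploitation trade-off automatically.

```python
from fractions import Fraction
from functools import lru_cache

ALPHA, BETA = 1, 1  # hypothetical uniform Beta prior on each arm's success probability

@lru_cache(maxsize=None)
def value(counts, steps_left):
    """Bayes-optimal expected return from belief `counts` with `steps_left` pulls remaining.

    counts is a tuple ((s0, f0), (s1, f1)) of successes/failures observed per arm.
    """
    if steps_left == 0:
        return Fraction(0)
    return max(q_value(counts, arm, steps_left) for arm in range(len(counts)))

def q_value(counts, arm, steps_left):
    """Expected return of pulling `arm` now and acting Bayes-optimally afterwards."""
    s, f = counts[arm]
    p_success = Fraction(ALPHA + s, ALPHA + BETA + s + f)  # posterior mean success probability
    won = counts[:arm] + ((s + 1, f),) + counts[arm + 1:]
    lost = counts[:arm] + ((s, f + 1),) + counts[arm + 1:]
    return (p_success * (1 + value(won, steps_left - 1))
            + (1 - p_success) * value(lost, steps_left - 1))

def best_arm(counts, steps_left):
    """Arm with the highest Bayes-optimal action value in the current belief state."""
    return max(range(len(counts)), key=lambda arm: q_value(counts, arm, steps_left))

start = ((0, 0), (0, 0))
print(value(start, 1), value(start, 2), float(value(start, 10)))
print(best_arm(((0, 1), (0, 0)), 1))  # after a failure on arm 0, switch to arm 1
```

With two pulls left, the optimal expected return is 13/12 rather than 1: the first pull is worth slightly more because what it reveals improves the second. The memoised recursion is only practical for small horizons, which is one reason the Bayes-optimal solution is intractable in general.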
The RNN’s internal states are reset at the beginning of each episode, meaning that activation-information cannot be carried over from one task to another. Instead, information that is shared across tasks must be represented in the network’s weights, which are kept fixed after training. Ultimately, the weight-values give rise to the adaptive algorithm that manifests itself via the network’s internal dynamics.
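The evaluation protocol can be sketched with a toy recurrent cell. The weights below are arbitrary hypothetical constants, not trained values; the point is only that the weights are shared and frozen across episodes while the hidden state is re-initialised at every episode boundary, so no activation-level information leaks between tasks.

```python
import math

class ToyRNNPredictor:
    """A one-unit recurrent predictor with fixed (hypothetical, untrained) weights."""

    def __init__(self, w_rec=0.9, w_in=1.5, bias=-0.75, w_out=2.0):
        self.w_rec, self.w_in, self.bias, self.w_out = w_rec, w_in, bias, w_out

    def initial_state(self):
        return 0.0

    def step(self, h, x):
        """Update the hidden state after observing flip x (1 = heads, 0 = tails)."""
        return math.tanh(self.w_rec * h + self.w_in * x + self.bias)

    def predict(self, h):
        """Read out P(heads) from the hidden state via a sigmoid."""
        return 1.0 / (1.0 + math.exp(-self.w_out * h))

def run_episodes(model, episodes):
    """Return per-episode prediction sequences, resetting the state at each episode start."""
    all_predictions = []
    for flips in episodes:
        h = model.initial_state()  # reset: nothing carries over from the previous task
        predictions = []
        for x in flips:
            predictions.append(model.predict(h))  # predict before seeing the flip
            h = model.step(h, x)
        all_predictions.append(predictions)
    return all_predictions

model = ToyRNNPredictor()
together = run_episodes(model, [[1, 1, 1], [0, 1, 0]])
alone = run_episodes(model, [[0, 1, 0]])
print(together[1] == alone[0])  # True: the second episode is unaffected by the first
```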
RNN agents behave Bayes-optimally
We illustrate the behaviour of a trained RNN and a known Bayes-optimal algorithm on one of our prediction tasks below. The task is the prediction of a categorical variable (a three-sided die), with per-category probabilities distributed according to a Dirichlet distribution. The illustration shows three typical episodes⁷.
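For this task the known Bayes-optimal predictor has a simple closed form: with a Dirichlet(α₁, α₂, α₃) prior and counts n_k of each face observed so far, the posterior predictive probability of face k is (α_k + n_k) / (α₁ + α₂ + α₃ + n). A minimal sketch, assuming a hypothetical symmetric Dirichlet(1, 1, 1) prior and a made-up episode:

```python
from fractions import Fraction

def dirichlet_predictive(observations, alphas):
    """Posterior-predictive distribution over die faces under a Dirichlet prior.

    observations: list of face indices in range(len(alphas)).
    alphas: Dirichlet concentration parameters, one per face.
    """
    counts = [0] * len(alphas)
    for face in observations:
        counts[face] += 1
    total = sum(alphas) + len(observations)
    return [Fraction(a + n, total) for a, n in zip(alphas, counts)]

ALPHAS = (1, 1, 1)  # hypothetical uniform prior over three-sided dice
episode = [0, 2, 0, 0, 1]
for t in range(len(episode) + 1):
    probs = dirichlet_predictive(episode[:t], ALPHAS)
    print(t, [round(float(p), 3) for p in probs])
```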
Comparing the outputs of the trained RNN (solid lines) and the known Bayes-optimal algorithm (dashed lines) confirms the theoretical prediction: both algorithms behave virtually indistinguishably. In our paper, we verify this quantitatively, via KL divergence for predictions and difference in return for bandits, across a large number of episodes and a range of tasks.
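The KL-based comparison for prediction tasks can be sketched as follows. The sketch assumes the direction KL(Bayes ‖ RNN), measured in nats and averaged over time steps; the “RNN” outputs below are hypothetical placeholder numbers.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in nats for two categorical distributions over the same outcomes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mean_step_kl(bayes_predictions, rnn_predictions):
    """Average per-step KL divergence from the Bayes-optimal to the RNN predictions."""
    kls = [kl_divergence(p, q) for p, q in zip(bayes_predictions, rnn_predictions)]
    return sum(kls) / len(kls)

# Hypothetical predictions for a three-sided die over two time steps.
bayes = [[1 / 3, 1 / 3, 1 / 3], [0.5, 0.25, 0.25]]
rnn = [[0.34, 0.33, 0.33], [0.49, 0.26, 0.25]]
print(mean_step_kl(bayes, rnn))
```

A value close to zero across many episodes indicates that the two predictors are behaviourally near-indistinguishable.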
Peeking under the hood: comparing computational structure
To compare the RNN’s internal dynamics against the known Bayes-optimal algorithm, we cast both as finite-state machines (FSMs). Through simulation, we can then determine whether two FSMs implement the same algorithm. As illustrated in the figure below, establishing a simulation relation between two FSMs requires that for each state in FSM A, we can find an equivalent state in FSM B such that, when both are given the same input symbol, we observe matching state-transitions and outputs in both machines. This has to hold for all states and input sequences.
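Under the exact definition, and assuming both machines are deterministic Mealy machines (each state and input symbol determine exactly one next state and one output), checking whether B simulates A from their initial states reduces to exploring the pairs of states the two machines reach on identical inputs and confirming that outputs agree at every pair. The exploration below is bounded by a depth parameter because the example machines have unbounded state spaces. The coin-flip predictors are hypothetical illustrations: a count-based Bayes predictor, one that stores the full history, and a mismatched “skeptic” with a different prior.

```python
from collections import deque
from fractions import Fraction

class Machine:
    """A deterministic Mealy machine: step(state, symbol) -> (next_state, output)."""

    def __init__(self, initial, step):
        self.initial = initial
        self.step = step

def simulates(b, a, alphabet, depth):
    """Check, for all input sequences up to length `depth`, that machine b simulates machine a.

    For deterministic machines the candidate relation is the set of state pairs
    reachable on identical inputs; it is a simulation iff outputs agree on every pair.
    """
    frontier = deque([(a.initial, b.initial, 0)])
    seen = set()
    while frontier:
        sa, sb, d = frontier.popleft()
        if (sa, sb) in seen or d == depth:
            continue
        seen.add((sa, sb))
        for x in alphabet:
            na, out_a = a.step(sa, x)
            nb, out_b = b.step(sb, x)
            if out_a != out_b:
                return False
            frontier.append((na, nb, d + 1))
    return True

def counts_step(state, x):
    """Bayes predictor, Beta(1, 1) prior: state is (heads, tails); output is P(next heads)."""
    nxt = (state[0] + x, state[1] + 1 - x)
    return nxt, Fraction(1 + nxt[0], 2 + nxt[0] + nxt[1])

def history_step(state, x):
    """Same predictions, but the state is the full (non-minimal) flip history."""
    nxt = state + (x,)
    return nxt, Fraction(1 + sum(nxt), 2 + len(nxt))

def skeptic_step(state, x):
    """Hypothetical mismatched predictor with a Beta(10, 10) prior."""
    nxt = (state[0] + x, state[1] + 1 - x)
    return nxt, Fraction(10 + nxt[0], 20 + nxt[0] + nxt[1])

bayes = Machine((0, 0), counts_step)
history = Machine((), history_step)
skeptic = Machine((0, 0), skeptic_step)

print(simulates(history, bayes, (0, 1), depth=8))  # True
print(simulates(skeptic, bayes, (0, 1), depth=8))  # False
```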
To apply the simulation argument in our case, we relax these conditions. We sample a large number of state-transitions and outputs by running our agents. The state-matching is implemented by training a neural network to regress the internal states of one agent-type onto the other.
On a set of held-out test trajectories, we:
- Produce a “matching state” in the simulating machine (B) by mapping the original state (in A) through the learned neural network regressor (dashed cyan line).
- Feed the same input to both agent-types. This gives us an original and a simulated state-transition and output.
- The new states “match” if we observe low regression error.
- The outputs “match” if we observe low behavioural dissimilarity, using the metrics we defined earlier to perform behavioural comparison.
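The relaxed procedure above can be sketched as an evaluation loop over held-out trajectories. In the study the state map is a trained neural-network regressor; here, purely as an assumption to keep the sketch self-contained, `state_map` is any callable from A-states to B-states, states are compared by squared error, and outputs by a Bernoulli KL divergence. The demo uses hypothetical coin-flip machines and a hand-written map in place of a learned one.

```python
import math

def kl_bernoulli(p, q):
    """KL(Bernoulli(p) || Bernoulli(q)) in nats; assumes 0 < q < 1."""
    return sum(pi * math.log(pi / qi) for pi, qi in ((p, q), (1 - p, 1 - q)) if pi > 0)

def relaxed_simulation_scores(trajectories, machine_a, machine_b, state_map):
    """Mean state-regression error and mean output KL for "B simulates A".

    Each machine is an (initial_state, step, output) triple, with step(state, x) -> state
    and output(state) -> P(heads). state_map stands in for the learned regressor.
    """
    init_a, step_a, out_a = machine_a
    _, step_b, out_b = machine_b
    state_errors, output_kls = [], []
    for flips in trajectories:
        s_a = init_a
        for x in flips:
            s_b = state_map(s_a)  # matching state in the simulating machine B
            next_a, next_b = step_a(s_a, x), step_b(s_b, x)  # same input to both
            mapped = state_map(next_a)
            state_errors.append(sum((u - v) ** 2 for u, v in zip(mapped, next_b)))
            output_kls.append(kl_bernoulli(out_a(next_a), out_b(next_b)))
            s_a = next_a
    n = len(state_errors)
    return sum(state_errors) / n, sum(output_kls) / n

# Hypothetical machines: A stores the full history, B stores (heads, tails) counts.
history_machine = ((), lambda s, x: s + (x,), lambda s: (1 + sum(s)) / (2 + len(s)))
counts_machine = ((0, 0), lambda s, x: (s[0] + x, s[1] + 1 - x),
                  lambda s: (1 + s[0]) / (2 + s[0] + s[1]))

def to_counts(history):
    """Hand-written stand-in for the learned regressor: history -> (heads, tails)."""
    return sum(history), len(history) - sum(history)

def swapped_counts(history):
    """A deliberately wrong map, for contrast: history -> (tails, heads)."""
    return len(history) - sum(history), sum(history)

held_out = [[1, 0, 1, 1], [0, 0, 1], [1, 1, 1, 1, 0]]
good = relaxed_simulation_scores(held_out, history_machine, counts_machine, to_counts)
bad = relaxed_simulation_scores(held_out, history_machine, counts_machine, swapped_counts)
print(good, bad)
```

With a perfect map both scores are zero; with a wrong map the state error is large, which is how a poor correspondence would show up.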
An illustration is shown below, using the three-sided die prediction example from before.
We find that in all our tasks, the known Bayes-optimal algorithm can be simulated by the meta-trained RNN, but not always vice versa. We hypothesise that this is due to the RNN representing the tasks’ sufficient statistics in a non-minimal fashion. For example, the sequences “heads-tails-heads” and “heads-heads-tails” will lead to precisely the same internal state in the known Bayes-optimal algorithm but might lead to two separate states in the RNN. Overall, we find very strong algorithmic correspondence (but not precise algorithmic equivalence in the technical sense) between the algorithm implemented by the trained RNN and the known Bayes-optimal algorithm.
Meta-trained agents implement Bayes-optimal agents
Meta-training is widespread in modern training schemes for AI systems. Viewing it through a Bayesian lens can provide important insight into the theoretical characteristics of the solutions obtained via meta-training. For instance, the remarkable performance of GPT-3 at one- or few-shot adaptation to many language tasks makes sense when we remind ourselves that GPT-3 is effectively meta-trained on a very large corpus of text that implicitly defines an incredibly broad range of tasks. Theory tells us that, at convergence, the solution to this meta-training process approaches the Bayes-optimal solution, which has minimal sample complexity for adapting to any of the tasks covered by the “meta-distribution”.
Together with previously published theoretical results and a long history of analysis of Bayes-optimality, our empirical study highlights the importance and merit of the Bayesian view on meta-learning with modern neural networks. Exciting next questions are whether the theory can be extended towards theoretical properties and guarantees or bounds of not fully converged solutions, as well as making statements about cases where the Bayes-optimal solution is not precisely within the model class of the meta-learner. Most importantly, we believe that the Bayesian analysis of meta-learning has great potential to reason about other important safety-relevant properties of meta-trained solutions, such as their exploration-exploitation trade-off, their risk-sensitivity, and their generalisation-behaviour as governed by the inductive biases that get “baked in” during training.
Find out more in our paper.
We would like to thank Arielle Bier and Jon Fildes for their help with this post.
³ Recent work on “concealed poisoning” of training data of a sentiment classification model illustrates how little data is required to induce very strong inductive biases. Using a handful of sentences like “J flows brilliant is great”, the trained model mistakenly classifies previously unseen negative James Bond movie reviews as positive in many cases.
⁴ Note that memory-based systems do not necessarily need to have an explicit internal memory, such as LSTM cells. It’s often easy to offload memorisation onto the environment, for example in agents that act in fairly deterministic environments, or in sequential generative models where outputs become part of the context of future inputs (e.g. in generative language models).
⁵ Optimality here typically refers to minimising log-loss or maximising return (cumulative reward).
⁶ The theoretical results hold for memory-based meta-learning at convergence in the “realisable case” (Bayes-optimal solution must be in the class of possible solutions of the meta-learner).
⁷ The dashed vertical lines in the figure give a glimpse of generalisation: RNNs were trained on episodes of 20 steps but are evaluated here for 30 time steps. We found that in many cases, RNNs trained on 20 steps would behave Bayes-optimally on episodes with thousands of time-steps. This is interesting and somewhat surprising, and warrants thorough investigation in future work.