A Meta-Transfer Objective for Learning to Disentangle Causal Mechanisms
Yoshua Bengio, Tristan Deleu, Nasim Rahaman, Rosemary Ke, Sébastien Lachapelle, Olexa Bilaniuk, Anirudh Goyal, Christopher Pal
arXiv e-Print archive, 2019
Keywords:
cs.LG, stat.ML
First published: 2019/01/30
Abstract: We propose to meta-learn causal structures based on how fast a learner adapts
to new distributions arising from sparse distributional changes, e.g. due to
interventions, actions of agents and other sources of non-stationarities. We
show that under this assumption, the correct causal structural choices lead to
faster adaptation to modified distributions because the changes are
concentrated in one or just a few mechanisms when the learned knowledge is
modularized appropriately. This leads to sparse expected gradients and a lower
effective number of degrees of freedom needing to be relearned while adapting
to the change. It motivates using the speed of adaptation to a modified
distribution as a meta-learning objective. We demonstrate how this can be used
to determine the cause-effect relationship between two observed variables. The
distributional changes do not need to correspond to standard interventions
(clamping a variable), and the learner has no direct knowledge of these
interventions. We show that causal structures can be parameterized via
continuous variables and learned end-to-end. We then explore how these ideas
could be used to also learn an encoder that would map low-level observed
variables to unobserved causal variables leading to faster adaptation
out-of-distribution, learning a representation space where one can satisfy the
assumptions of independent mechanisms and of small and sparse changes in these
mechanisms due to actions and non-stationarities.
How can we learn causal relationships that explain data? One answer: learn from non-stationary distributions. If we train models with different factorizations of the joint distribution over the variables, we can observe which factorization adapts to distributional shifts with better sample complexity; that factorization is more likely to reflect the true causal structure.
If we consider two variables $A$ and $B$, we can factor their joint distribution in two ways (both parameterizations are sketched in code below):
$P(A,B) = P(A)P(B|A)$ representing a causal graph like $A\rightarrow B$
$P(A,B) = P(A|B)P(B)$ representing a causal graph like $A \leftarrow B$
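As a concrete illustration, here is a minimal sketch (my own, not the authors' code) of how the two factorizations could be parameterized for discrete variables with $N$ categories each, using tabular softmax parameters in PyTorch. The class names, the value $N = 10$, and the tabular parameterization are assumptions made for illustration:

```python
import torch
import torch.nn.functional as F

N = 10  # number of categories per variable (assumed for illustration)

class ModelAtoB(torch.nn.Module):
    """Models P(A,B) = P(A) P(B|A), i.e. the causal graph A -> B."""
    def __init__(self, n=N):
        super().__init__()
        self.logits_a = torch.nn.Parameter(torch.zeros(n))              # logits of P(A)
        self.logits_b_given_a = torch.nn.Parameter(torch.zeros(n, n))   # row a: logits of P(B|A=a)

    def log_prob(self, a, b):
        """a, b: LongTensors of category indices with matching shape."""
        log_pa = F.log_softmax(self.logits_a, dim=0)[a]
        log_pb_given_a = F.log_softmax(self.logits_b_given_a, dim=1)[a, b]
        return log_pa + log_pb_given_a

class ModelBtoA(torch.nn.Module):
    """Models P(A,B) = P(B) P(A|B), i.e. the causal graph B -> A."""
    def __init__(self, n=N):
        super().__init__()
        self.logits_b = torch.nn.Parameter(torch.zeros(n))              # logits of P(B)
        self.logits_a_given_b = torch.nn.Parameter(torch.zeros(n, n))   # row b: logits of P(A|B=b)

    def log_prob(self, a, b):
        log_pb = F.log_softmax(self.logits_b, dim=0)[b]
        log_pa_given_b = F.log_softmax(self.logits_a_given_b, dim=1)[b, a]
        return log_pb + log_pa_given_b
```

Both models can represent exactly the same set of joint distributions; the asymmetry only appears in how they adapt after a shift, as explained next.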
The idea is that if we train a model with one of these structures, it will take longer to adapt to a new, shifted data distribution when the model does not have the correct inductive bias. For example, suppose the true relationship is $A$ = Raining causes $B$ = Open Umbrella (and not vice versa). Changing the marginal probability of Raining (say, because the weather changed) does not change the mechanism relating $A$ and $B$ (captured by $P(B|A)$), but it does change the marginal $P(B)$.
So after this distributional shift, the module that models $P(B|A)$ does not need to change, because the mechanism is unchanged; only the module that models $P(A)$ needs to change. Under the incorrect factorization $P(B)P(A|B)$, adaptation will be slower because both $P(B)$ and $P(A|B)$ have to be modified to account for the change in $P(A)$ (by Bayes' rule).
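Continuing the sketch above, adaptation speed can be measured by drawing data from a shifted distribution in which only the marginal $P(A)$ changes (the mechanism $P(B|A)$ stays fixed), then fine-tuning each pretrained model and accumulating its online log-likelihood. Everything below (the data helper, the closed-form pretraining, the learning rate and sample counts) is an assumed illustration, not the paper's exact protocol:

```python
import copy
import numpy as np
import torch

rng = np.random.default_rng(0)
TRUE_COND = rng.dirichlet(np.ones(N), size=N)   # ground-truth P(B|A): row a holds P(B|A=a)

def sample_transfer_data(n_samples=500):
    """Draw (a, b) pairs after an intervention that changes only the marginal P(A)."""
    p_a = rng.dirichlet(np.ones(N))                               # new random marginal on A
    a = rng.choice(N, size=n_samples, p=p_a)
    b = np.array([rng.choice(N, p=TRUE_COND[ai]) for ai in a])
    return torch.as_tensor(a), torch.as_tensor(b)

def online_log_likelihood(model, a, b, batch_size=10, lr=0.1):
    """Fine-tune on the shifted data, summing each batch's log-likelihood
    *before* the gradient step; a higher total means faster adaptation."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    total = 0.0
    for i in range(0, len(a), batch_size):
        logp = model.log_prob(a[i:i + batch_size], b[i:i + batch_size]).mean()
        total += logp.item()
        opt.zero_grad()
        (-logp).backward()
        opt.step()
    return total

# Pretrain both models on the original (pre-shift) distribution. Because the
# models are tabular, the maximum-likelihood solution can be set in closed form.
P_A_TRAIN = rng.dirichlet(np.ones(N))
joint = P_A_TRAIN[:, None] * TRUE_COND          # P(A=a, B=b) under the training distribution
p_b = joint.sum(axis=0)                         # P(B)
p_a_given_b = (joint / p_b).T                   # row b holds P(A|B=b)

model_ab, model_ba = ModelAtoB(), ModelBtoA()
with torch.no_grad():
    model_ab.logits_a.copy_(torch.log(torch.as_tensor(P_A_TRAIN)))
    model_ab.logits_b_given_a.copy_(torch.log(torch.as_tensor(TRUE_COND)))
    model_ba.logits_b.copy_(torch.log(torch.as_tensor(p_b)))
    model_ba.logits_a_given_b.copy_(torch.log(torch.as_tensor(p_a_given_b)))

# The correctly factorized model should, on average, accumulate a higher
# online log-likelihood on the transfer data.
a, b = sample_transfer_data()
print("A->B online log-likelihood:", online_log_likelihood(copy.deepcopy(model_ab), a, b))
print("B->A online log-likelihood:", online_log_likelihood(copy.deepcopy(model_ba), a, b))
```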
This difference in sample complexity can be observed when fitting the joint of the shifted distribution: the $B\rightarrow A$ model takes longer to adapt:
https://i.imgur.com/B9FEmA7.png
The speed of adaptation (sample complexity) on the shifted distribution thus serves as a meta-learning signal telling us which causal graph is the correct inductive bias.
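The abstract also states that the causal structure can be parameterized via continuous variables and learned end-to-end. Below is a hedged sketch of one way to do this, continuing the snippets above (the paper's exact meta-objective may differ): a scalar $\gamma$ is introduced with $\sigma(\gamma)$ the probability assigned to the hypothesis $A \rightarrow B$, and after every adaptation episode $\gamma$ is updated by gradient descent on the negative log of the mixture of the two hypotheses' online likelihoods, so it drifts toward whichever factorization adapted faster:

```python
import copy

gamma = torch.zeros(1, requires_grad=True)      # sigmoid(gamma) = belief that A -> B is the true graph
meta_opt = torch.optim.SGD([gamma], lr=0.5)

for episode in range(200):                      # number of meta-episodes is an arbitrary choice
    a, b = sample_transfer_data()               # each episode uses a freshly shifted P(A)
    # Each hypothesis adapts from a copy of its pretrained parameters, so the
    # episode measures adaptation speed rather than accumulated training.
    logL_ab = online_log_likelihood(copy.deepcopy(model_ab), a, b)
    logL_ba = online_log_likelihood(copy.deepcopy(model_ba), a, b)
    # Meta-objective: negative log of the mixture of the two hypotheses'
    # online likelihoods. The log-likelihoods are plain floats here, so only
    # gamma receives a gradient.
    regret = -torch.logsumexp(torch.stack([
        F.logsigmoid(gamma) + logL_ab,
        F.logsigmoid(-gamma) + logL_ba,
    ]), dim=0)
    meta_opt.zero_grad()
    regret.backward()
    meta_opt.step()

print("P(A -> B) =", torch.sigmoid(gamma).item())
```

If the ground truth is $A \rightarrow B$, $\sigma(\gamma)$ should drift toward 1 over episodes, since that hypothesis accumulates higher online log-likelihood on each transfer set.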
This works experimentally, and the authors also observe that the gap between the two models appears to grow as model capacity increases.
This summary was written with the help of Yoshua Bengio.