
Axiomatic training lets LLMs learn causal reasoning: a 67-million-parameter model is comparable to the trillion-parameter GPT-4

2024-07-16




Machine Heart Report

Editor: Panda

Show the causal chain to the LLM and it will learn the axioms.

AI is already helping mathematicians and scientists do research. For example, the famous mathematician Terence Tao has repeatedly shared his experience of using AI tools such as GPT to conduct research. If AI is to make great strides in these fields, strong and reliable causal reasoning capabilities are essential.

The paper covered here shows that a Transformer model trained on demonstrations of the causal transitivity axiom on small graphs can generalize that axiom to large graphs.

In other words, a Transformer that learns to perform simple causal reasoning can go on to perform more complex causal reasoning. The team's axiomatic training framework is a new paradigm for learning causal reasoning from passive data; it can be used to learn arbitrary axioms, provided there are enough demonstrations.

Introduction

Causal reasoning can be defined as a set of reasoning procedures that conform to predefined axioms or rules specific to causality. For example, d-separation and do-calculus rules can be considered axioms, while the specification of collider sets or backdoor sets can be considered rules derived from the axioms.

Typically, causal inference uses data that corresponds to variables in a system. Axioms or rules can be integrated into machine learning models in the form of inductive biases through regularization, model architecture, or specific variable selection.

Judea Pearl's "Ladder of Causation" defines the types of causal inference that are possible, depending on the kind of data available (observational, interventional, or counterfactual).

Since axioms are the building blocks of causal reasoning, a natural question is whether a machine learning model can learn them directly. In other words, instead of learning from data produced by some data-generating process, could a model learn symbolic representations of the axioms directly (and thereby learn causal reasoning)?

Compared with task-specific causal models built for particular data distributions, such a model has the advantage of being usable for causal reasoning across many downstream scenarios. The question has become increasingly important now that language models can learn from symbolic data expressed in natural language.

In fact, some recent studies have evaluated whether large language models (LLMs) can perform causal reasoning by creating benchmarks that encode causal reasoning problems in natural language.

A team of researchers from Microsoft, MIT, and the Indian Institute of Technology Hyderabad (IIT Hyderabad) has taken an important step in this direction: they propose a method for learning causal reasoning through axiomatic training.



  • Paper title: Teaching Transformers Causal Reasoning through Axiomatic Training
  • Paper address: https://arxiv.org/pdf/2407.07612

Axiom Training

The team assumes that a causal axiom can be represented as a symbolic tuple 〈premise, hypothesis, result〉. The hypothesis is the causal statement being tested; the premise is any relevant information used to decide whether that statement is "true"; and the result is the answer, which can be a simple "yes" or "no".

For example, the collider axiom from the paper "Can large language models infer causation from correlation?" can be expressed as such a tuple, with the result "yes".



Based on this template, a large number of synthetic tuples can be generated by modifying the variable names, number of variables, and order of variables.
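As a rough illustration of this template idea, the sketch below stores each instance as a 〈premise, hypothesis, result〉 record and generates variants by changing the number of variables, their names, and their order. Because the collider tuple itself is not reproduced here, the sketch uses the transitivity axiom discussed below instead. The class name, the sentence wording, and the sampling scheme are all hypothetical, not the authors' generator.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class AxiomTuple:
    """One <premise, hypothesis, result> instance (hypothetical container)."""
    premise: str
    hypothesis: str
    result: str  # "Yes" or "No"


def chain_premise(names):
    # Render a causal chain n0 -> n1 -> ... as natural-language sentences.
    return " ".join(f"{a} causes {b}." for a, b in zip(names, names[1:]))


def template_variants(num_vars_range, name_pool, seed=0):
    """Instantiate one transitivity template with varying variable count, names and order."""
    rng = random.Random(seed)
    out = []
    for n in num_vars_range:
        names = rng.sample(name_pool, n)  # fresh names in a random order
        premise = chain_premise(names)
        # First node reaches last node through the chain, so transitivity gives Yes.
        out.append(AxiomTuple(premise, f"Does {names[0]} cause {names[-1]}?", "Yes"))
        # The reverse direction is not implied by the premise, so the label is No.
        out.append(AxiomTuple(premise, f"Does {names[-1]} cause {names[0]}?", "No"))
    return out


print(template_variants(range(3, 5), list("XYZABC"))[0])
```

Changing only the seed, the name pool, or the range of sizes yields new instances from the same template, which is what makes the synthetic dataset cheap to scale.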

To teach a Transformer causal axioms through axiomatic training, the team made specific design choices for the dataset, the loss function, and the positional embedding.

Axiom Training: Datasets, Loss Functions, and Positional Encoding

Training Data

Based on a specific axiom, the "hypothesis" can be mapped to the appropriate label (Yes or No) according to the "premise". To create the training dataset, the team enumerated all possible tuples {(P, H, L)}_N under a specific variable setting X, Y, Z, A, where P is the premise, H is the hypothesis, and L is the label (Yes or No).

Given a premise P derived from a causal graph, the label L is Yes if the hypothesis H can be derived from P by applying the axiom(s) one or more times; otherwise, it is No.

For example, suppose the true causal graph underlying a system has a chain topology: X_1 → X_2 → X_3 → ... → X_n. A possible premise is then X_1 → X_2 ∧ X_2 → X_3; given this premise, the hypothesis X_1 → X_3 has the label Yes, while the hypothesis X_3 → X_1 has the label No. Applying the axiom inductively, multiple times, yields more complex training tuples.
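The labeling rule can be sketched as a transitive closure: keep applying "A → B and B → C imply A → C" until nothing new can be derived, then answer Yes exactly when the hypothesized edge is in the result. This is a minimal illustration that assumes premises are given as lists of directed edges; the function names are hypothetical.

```python
def transitive_closure(edges):
    """Apply the transitivity axiom (A->B and B->C imply A->C) until a fixed point."""
    closure = set(edges)
    changed = True
    while changed:
        changed = False
        for a, b in list(closure):
            for c, d in list(closure):
                if b == c and (a, d) not in closure:
                    closure.add((a, d))  # newly derived edge
                    changed = True
    return closure


def label(premise_edges, hypothesis):
    # "Yes" if the hypothesized edge follows from the premise by transitivity, else "No".
    return "Yes" if hypothesis in transitive_closure(premise_edges) else "No"


premise = [("X1", "X2"), ("X2", "X3")]
print(label(premise, ("X1", "X3")), label(premise, ("X3", "X1")))
```

The fixed-point loop is only meant for the small graphs discussed here; on larger graphs, a reachability search (BFS or DFS) from the hypothesis's source node gives the same answer more cheaply.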

For training, a synthetic dataset D is built from N instances of the transitivity axiom. Each instance in D takes the form (P_i, H_ij, L_ij), where the i-th premise describes a chain of n nodes and j ranges over the hypotheses asked about it. P is the premise, a natural-language description of a causal structure (such as "X causes Y. Y causes Z."); H is the question (such as "Does X cause Y?"); and L is the label (Yes or No). In this form, the dataset can cover every pair of nodes for each unique chain in a given causal graph.
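The row format above can be sketched as follows for a single chain. It assumes the sentence wording shown in the article ("X causes Y.", "Does X cause Y?") and distinct node names; the authors' exact wording and ordering may differ, and the function name is hypothetical.

```python
from itertools import permutations


def chain_instances(nodes):
    """Build (P_i, H_ij, L_ij) rows for one chain nodes[0] -> nodes[1] -> ... (hypothetical format)."""
    premise = " ".join(f"{a} causes {b}." for a, b in zip(nodes, nodes[1:]))
    position = {name: k for k, name in enumerate(nodes)}
    rows = []
    # Every ordered pair of distinct nodes becomes one hypothesis about the same premise.
    for a, b in permutations(nodes, 2):
        # On a chain, a causes b (directly or via intermediate nodes) iff a comes before b.
        lab = "Yes" if position[a] < position[b] else "No"
        rows.append((premise, f"Does {a} cause {b}?", lab))
    return rows


for row in chain_instances(["X", "Y", "Z"]):
    print(row)
```

On a pure chain, reachability reduces to comparing positions, so no closure computation is needed; graphs with branches would need a general reachability check instead.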



Loss Function

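Given a dataset D, the loss function is defined only on the ground-truth label of each tuple: $\mathcal{L} = -\mathbb{E}_{(P, H, L) \sim D}\left[\log p_\theta(L \mid P, H)\right]$, that is, the model is trained to predict the label given the premise and hypothesis rather than every token in the sequence. The team's analysis shows that this loss achieves promising results compared with standard next-token prediction.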
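To make the contrast with next-token prediction concrete, here is a minimal sketch of both objectives on plain Python lists. It assumes the model's output at the answer position is a vector of logits and the label is an index into it; in a real model these logits would come from the Transformer's final layer. The function names are hypothetical.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax over one vector of logits.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def label_only_loss(batch):
    """batch: list of (logits_at_answer_position, label_id); mean of -log p(L | P, H)."""
    return -sum(log_softmax(logits)[y] for logits, y in batch) / len(batch)


def next_token_loss(sequence_logits, token_ids):
    """Standard LM loss: mean of -log p(t_{k+1} | t_1..t_k) over every position."""
    # sequence_logits[k] is the model's prediction for token_ids[k + 1].
    terms = [log_softmax(sequence_logits[k])[token_ids[k + 1]]
             for k in range(len(token_ids) - 1)]
    return -sum(terms) / len(terms)


print(label_only_loss([([0.0, 0.0], 1)]))  # uniform over two labels: log 2
```

The label-only objective spends all of its gradient on the Yes/No decision, whereas next-token prediction also trains the model to reproduce the premise and question text.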



Positional Encoding

Besides the training data and loss function, the choice of positional encoding is another important factor. Positional encoding provides key information about a token's absolute and relative position in a sequence.

The famous paper "Attention Is All You Need" proposed an absolute position encoding strategy that uses periodic functions (sine and cosine) to define these encodings.
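For reference, the sinusoidal scheme from "Attention Is All You Need" sets PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), where d is the embedding dimension. A minimal pure-Python version (the function name is hypothetical):

```python
import math


def sinusoidal_encoding(num_positions, d_model):
    """PE[pos][2i] = sin(pos / 10000^(2i/d)), PE[pos][2i+1] = cos(pos / 10000^(2i/d))."""
    pe = [[0.0] * d_model for _ in range(num_positions)]
    for pos in range(num_positions):
        for i in range(0, d_model, 2):  # i plays the role of 2i in the formula
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


print(sinusoidal_encoding(2, 4))
```

Because every entry is a closed-form function of the position, the encoding is defined for any position, including ones far beyond the lengths seen in training.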

Sinusoidal absolute position encodings (APE) provide deterministic values for every position of a sequence of any length. However, studies have shown that absolute position encodings struggle with length generalization in Transformers. In the learnable APE variant, each position embedding is randomly initialized and trained along with the model. This approach struggles with sequences longer than those seen during training, because the embeddings for the new positions were never trained.

Interestingly, recent findings suggest that removing position embeddings from autoregressive models can improve length generalization, and that the causal attention mechanism used in autoregressive decoding is enough to encode position information on its own. The team compared different position encodings to understand their impact on generalization in causal tasks: learnable position encoding (LPE), sinusoidal position encoding (SPE), and no position encoding (NoPE).
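The three settings can be contrasted in a small sketch that adds positional information to a list of token vectors. Everything here is a hypothetical illustration: the function and argument names are ours, and the LPE table is simply a list of vectors with one row per position it was trained for. Real implementations usually allocate a larger table, in which case rows past the training length exist but never received gradient updates; the sketch raises an error instead to make that limit explicit.

```python
import math


def add_position_info(token_embs, mode, lpe_table=None):
    """Add positional information to token vectors under "spe", "lpe" or "nope".

    token_embs: list of equal-length vectors, one per token.
    lpe_table: learned per-position vectors (used only when mode == "lpe").
    """
    d = len(token_embs[0])
    out = []
    for pos, emb in enumerate(token_embs):
        if mode == "nope":
            # No positional signal; order must come from causal attention itself.
            out.append(list(emb))
        elif mode == "spe":
            # Fixed sinusoid: defined for any position, no training needed.
            pe = [math.sin(pos / 10000 ** (i / d)) if i % 2 == 0
                  else math.cos(pos / 10000 ** ((i - 1) / d)) for i in range(d)]
            out.append([e + p for e, p in zip(emb, pe)])
        elif mode == "lpe":
            # Learned table: only positions it has rows for carry trained values.
            if lpe_table is None or pos >= len(lpe_table):
                raise IndexError(f"no learned embedding for position {pos}")
            out.append([e + p for e, p in zip(emb, lpe_table[pos])])
        else:
            raise ValueError(f"unknown mode: {mode}")
    return out
```

Under this framing, SPE and NoPE both handle a sequence of any length, while LPE is tied to the positions it was trained on, which matches the length-generalization concern described above.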

To improve the generalization ability of the model, the team also used data perturbations, including perturbations of length, node names, chain order, and branching.
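The article does not give the exact perturbation procedure, so the following is only a hypothetical sketch of the four kinds of variation it names: a random chain length (length), random node names of varying length (names), randomly reversed edges (order), and an occasional extra branch (branching). The probabilities, name lengths, and function names are made-up defaults.

```python
import random
import string


def random_name(rng, min_len=1, max_len=3):
    # Node-name perturbation: random alphabetic names of variable length.
    k = rng.randint(min_len, max_len)
    return "".join(rng.choice(string.ascii_letters) for _ in range(k))


def perturbed_graph(rng, min_nodes=3, max_nodes=6, p_reverse=0.3, p_branch=0.2):
    """Sample one training graph as (node names, directed edges); hypothetical scheme."""
    n = rng.randint(min_nodes, max_nodes)  # length perturbation
    names = []
    while len(names) < n:  # keep names unique
        name = random_name(rng)
        if name not in names:
            names.append(name)
    edges = []
    for a, b in zip(names, names[1:]):
        # Order perturbation: flip this edge's direction with probability p_reverse.
        edges.append((b, a) if rng.random() < p_reverse else (a, b))
    if rng.random() < p_branch:
        # Branching perturbation: attach one new node to a random existing node.
        new = random_name(rng)
        while new in names:
            new = random_name(rng)
        edges.append((rng.choice(names), new))
        names.append(new)
    return names, edges


print(perturbed_graph(random.Random(0)))
```

Mixing such graphs with plain chains mirrors, in spirit, the mixed training set that the team reports generalizes best.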

Experiments

Now the question arises: if a model is trained using this data, can the model learn to apply this axiom to new scenarios?

To answer this question, the team trained a Transformer model from scratch on symbolic demonstrations of the causal transitivity axiom.

To evaluate generalization, they trained on simple chains of 3-6 nodes demonstrating the transitivity axiom, then tested several aspects of generalization: length (chains of 7-15 nodes), names (longer variable names), order (chains with reversed edges or shuffled nodes), and structure (graphs with branches). Figure 1 shows how the structural generalization of the Transformer is evaluated.



Specifically, they trained a 67-million-parameter decoder-only model based on the GPT-2 architecture, with 12 attention layers, 8 attention heads, and an embedding dimension of 512. The model was trained from scratch on each training dataset. To understand the impact of positional embeddings, they studied three settings: sinusoidal position encoding (SPE), learnable position encoding (LPE), and no position encoding (NoPE).
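As a back-of-the-envelope check on the reported size, the sketch below counts the parameters of a GPT-2-style decoder with 12 layers and 512-dimensional embeddings. The vocabulary size (GPT-2's 50,257 tokens), the 1,024 learned positions, the 4x MLP expansion, and the tied input/output embeddings are our assumptions, since the article does not state them. The number of heads does not change the count; it only splits the 512 dimensions into 64-dimensional heads.

```python
def gpt2_param_count(n_layer=12, d_model=512, n_head=8, vocab_size=50257,
                     n_positions=1024, learned_pos=True):
    """Rough GPT-2-style parameter count with tied input/output embeddings (assumed)."""
    assert d_model % n_head == 0  # heads partition d_model; they add no parameters
    attn = (d_model * 3 * d_model + 3 * d_model) + (d_model * d_model + d_model)  # QKV + output proj
    mlp = (d_model * 4 * d_model + 4 * d_model) + (4 * d_model * d_model + d_model)  # 4x expansion MLP
    norms = 2 * (2 * d_model)  # two LayerNorms (scale + bias) per block
    blocks = n_layer * (attn + mlp + norms)
    embeddings = vocab_size * d_model + (n_positions * d_model if learned_pos else 0)
    final_norm = 2 * d_model
    return blocks + embeddings + final_norm


print(f"{gpt2_param_count():,}")
```

Under these assumptions the total is about 64 million (about 63.6 million without learned positions), the same order as the reported 67 million; the gap presumably reflects details the article does not give, such as the actual vocabulary size.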

The results are shown in Table 1 and Figures 3 and 4.



Table 1 shows the accuracy of different models when evaluated on larger causal chains that have not been seen during training. It can be seen that the performance of the new model TS2 (NoPE) is comparable to the trillion-parameter GPT-4.

Figure 3 shows generalization results on causal sequences whose node names are longer than those in the training set, along with the impact of the different positional embeddings.



Figure 4 evaluates generalization to longer unseen causal sequences.



They found that models trained on simple chains generalize to applying the axiom multiple times on larger chains, but fail to generalize to more complex scenarios such as order or structural generalization. However, when trained on a mix of simple chains and chains with randomly reversed edges, the model generalizes well across a variety of evaluation scenarios.

Extending earlier findings on length generalization in NLP tasks, they found that positional embeddings are important for causal generalization, both in length and in other respects. Their best-performing model used no positional encoding, though sinusoidal encodings also worked well in some cases.

This axiomatic training approach also extends to a harder problem, shown in Figure 5: given a premise containing statements of statistical independence, the task is to distinguish correlation from causation. Solving it requires knowledge of multiple axioms, including d-separation and the Markov property.
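To show the kind of axiom knowledge this task relies on, here is a minimal d-separation test using the standard ancestral-moral-graph criterion: X and Y are d-separated by Z if they are disconnected in the moralized graph of the ancestors of {X, Y} ∪ Z once Z is removed. This is an illustrative sketch, not the authors' data-generation code; it assumes a DAG given as (parent, child) edges and that X and Y are not in the conditioning set.

```python
from itertools import combinations


def d_separated(edges, x, y, given=()):
    """True if x and y are d-separated by `given` in the DAG with (parent, child) edges."""
    parents = {}
    for u, v in edges:
        parents.setdefault(v, set()).add(u)
        parents.setdefault(u, set())
    # 1. Keep only ancestors of {x, y} and the conditioning set (including themselves).
    relevant, stack = set(), [x, y, *given]
    while stack:
        n = stack.pop()
        if n not in relevant:
            relevant.add(n)
            stack.extend(parents.get(n, ()))
    # 2. Moralize: drop directions and "marry" parents that share a child.
    adj = {n: set() for n in relevant}
    for child in relevant:
        ps = list(parents.get(child, ()))
        for p in ps:
            adj[child].add(p)
            adj[p].add(child)
        for a, b in combinations(ps, 2):
            adj[a].add(b)
            adj[b].add(a)
    # 3. Remove conditioned nodes and check whether x can still reach y.
    blocked, seen, stack = set(given), {x}, [x]
    while stack:
        n = stack.pop()
        if n == y:
            return False
        for m in adj[n]:
            if m not in seen and m not in blocked:
                seen.add(m)
                stack.append(m)
    return True


collider = [("A", "C"), ("B", "C")]
print(d_separated(collider, "A", "B"), d_separated(collider, "A", "B", {"C"}))
```

The collider case illustrates why such premises are informative: A and B are independent marginally but become dependent once their common effect C is conditioned on.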



The team used the same method to generate synthetic training data and train a model. The results show that a Transformer trained on task demonstrations with 3-4 variables can learn to solve graph tasks with 5 variables, and on this task its accuracy exceeds that of larger LLMs such as GPT-4 and Gemini Pro.



The team said: "Our work provides a new paradigm for teaching models to learn causal reasoning through symbolic demonstrations of axioms, which we call axiomatic training." The data generation and training process of this method is universal: as long as an axiom can be represented in the format of a symbolic tuple, it can be learned using this method.