Humans are blessed with the ability to reason: “if” and “why,” along with the ability to “read between the lines” and infer unspoken information, are all critical to our problem-solving abilities.
Until now, AI models have, naturally, struggled in this area. But researchers out of Stanford University and Notbad AI, Inc. have now revealed that they have taught AI models to think before responding to prompts – just as (most) people consider what to say before they speak.
The researchers have introduced Quiet-STaR – an extension of the Self-Taught Reasoner (STaR) model – which is trained on a broad corpus of web data and learns to generate a rationale at each token to explain future text and improve its predictions.
Quiet-STaR was applied to Mistral 7B and demonstrated improvements to the model’s zero-shot direct-reasoning abilities on the CommonsenseQA question-answering challenge (from 36.3% to 47.2%) and the GSM8K grade-school math word problems dataset (from 5.9% to 10.9%). And these improvements increased steadily with the number of tokens used in the model’s “internal thoughts.”
“Quiet-STaR marks a step toward LMs that can learn to reason in a more general and scalable way,” the researchers write.
Where AI reasoning has fallen short so far
Previous methods that helped language models learn from their reasoning were more narrowly focused and less generalizable: AI models were trained to solve individual tasks or predefined sets of tasks that relied on carefully curated datasets.
For example, a pre-trained language model fine-tuned to output human reasoning traces before answering multiple-choice questions outperformed an AI trained directly on answers, the Quiet-STaR developers point out. Other models, when given a “scaffold,” can generate chain-of-thought solutions without additional supervision. Additionally, researchers have “forced” models to use chains of thought by preventing them from giving answers unless they are completely certain.
“However, once again, these approaches only work for a question-answer dataset,” write the researchers from Stanford University and Notbad AI, Inc.
STaR, in particular, proved that models can “bootstrap” their reasoning abilities using question-answer datasets: they can attempt rationales when trying to answer questions, train on those rationales when they lead to correct answers, and repeat the procedure iteratively to solve increasingly difficult problems.
However, the Quiet-STaR researchers point out that training on curated datasets limits the “scope and generalizability” of reasoning. High-quality datasets will “by nature only ever cover a subset of reasoning tasks.”
Inferring rationales from a few examples in question answering is a “highly constrained” setting, the researchers claim. “Ideally, a language model could instead learn to infer unspoken rationales from any text.”
By extending STaR, the researchers say, “we allow the LM to learn from the diverse tasks present in language. To our knowledge, this is the first work explicitly training LMs to reason generally from text, rather than from curated reasoning tasks or collections of reasoning tasks.”
“Quiet” thinking
The researchers at Stanford University and Notbad AI, Inc. call their technique Quiet-STaR because it applies STaR “quietly.”
The method generates many inner thoughts in parallel, at every token, to explain future text before responding to a prompt (i.e., the process of “thinking”). When the AI finally responds, it produces a mixture of predictions made with and without a rationale.
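To make the idea concrete, here is a minimal PyTorch sketch of a single such step – the toy model, sizes, and helper names are illustrative stand-ins, not the paper’s implementation:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab_size, hidden = 100, 32

    # Toy stand-in for a language model; in Quiet-STaR this would be a
    # real LM such as Mistral 7B.
    embed = torch.nn.Embedding(vocab_size, hidden)
    lm_head = torch.nn.Linear(hidden, vocab_size)

    def next_token_logits(token_ids):
        # Crude "LM": mean-pool the context embeddings, project to vocab.
        return lm_head(embed(token_ids).mean(dim=0))

    def sample_thought(context, length=4):
        # Sample a short rationale ("thought") continuing the context.
        ids = context.clone()
        for _ in range(length):
            probs = F.softmax(next_token_logits(ids), dim=-1)
            ids = torch.cat([ids, torch.multinomial(probs, 1)])
        return ids[len(context):]

    # For one position: predict the next token with and without a thought,
    # then mix the two distributions (Quiet-STaR learns this weight; the
    # 0.5 here is only a placeholder).
    context = torch.tensor([5, 17, 42])
    base = F.softmax(next_token_logits(context), dim=-1)
    thoughtful = F.softmax(
        next_token_logits(torch.cat([context, sample_thought(context)])),
        dim=-1)
    mixed = 0.5 * thoughtful + 0.5 * base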
The REINFORCE algorithm was then applied; in reinforcement learning, this involves collecting samples during an episode to update the policy parameters, along with the embeddings of the tokens marking the start and end of a thought. The researchers explain that this helps increase the likelihood that the AI accurately predicts future text. The model also discards rationales that lead to incorrect predictions.
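In rough code, a single REINFORCE step could look something like the following sketch (the numbers are illustrative placeholders, not real model outputs):

    import torch

    # Reward: how much a sampled thought improved the log-likelihood of
    # the true future tokens, relative to predicting without a thought.
    logp_future_with_thought = torch.tensor(-1.2)
    logp_future_without = torch.tensor(-2.0)  # baseline
    reward = (logp_future_with_thought - logp_future_without).detach()

    # Log-probability the model assigned to the thought tokens it sampled
    # (in practice, a sum over the thought's tokens, kept in the graph).
    logp_thought = torch.tensor(-3.5, requires_grad=True)

    # Policy-gradient loss: raise the probability of thoughts that helped.
    loss = -reward * logp_thought
    loss.backward()
    print(logp_thought.grad)  # -reward: a helpful thought is reinforced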
“By iteratively optimizing these parameters, Quiet-STaR trains the model to generate more useful reasoning throughout training,” the researchers write.
Since their goal was generalist reasoning, they used a zero-shot prompt (“Let’s think step by step”) without in-context examples. Quiet-STaR was applied to Mistral 7B using the web text datasets OpenWebMath and Colossal Clean Crawled Corpus.
“Quiet-STaR… enables a model to think quietly at every token, with a distribution trained to be useful,” the researchers write.
They add that “Quiet-STaR paves the way to more robust and adaptable language models by training on the rich spectrum of reasoning tasks embedded in diverse web text, rather than specializing narrowly in particular datasets.”
Closing the gap between model and human reasoning skills
Specifically, the researchers developed a parallel sampling algorithm that generates rationales from all tokens in a string. It gave the tokens the ability to “attend to themselves,” to all preceding tokens within the same thought, and to the preceding text. This allows “parallel continuation of all thoughts,” with each inference call generating one additional token for all thoughts.
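A simplified sketch of such an attention mask, assuming n text tokens each followed by a thought of length t (this layout and the sizes are a hypothetical simplification of the paper’s actual masking):

    import torch

    n, t = 4, 2                  # 4 text tokens, thoughts of 2 tokens each
    size = n + n * t             # text tokens followed by all thought tokens
    mask = torch.zeros(size, size, dtype=torch.bool)

    # Text tokens: ordinary causal attention over the preceding text.
    for i in range(n):
        mask[i, : i + 1] = True

    # Thought tokens: attend to the text up to their position and to
    # earlier tokens of the *same* thought - never to other thoughts.
    for i in range(n):
        for j in range(t):
            row = n + i * t + j
            mask[row, : i + 1] = True                        # preceding text
            mask[row, n + i * t : n + i * t + j + 1] = True  # own thought
    print(mask.int())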
The researchers introduced custom meta-tokens at the start and end of each thought. <|startofthought|> and <|endofthought|> were initialized with the em dash, “---,” which often denotes a pause.
“Intuitively, the start-of-thought tokens can be understood as putting the model into a ‘thinking mode,’” the researchers explain, “and the end-of-thought tokens can be understood as telling the model when it is done thinking.”
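A sketch of how such meta-tokens could be added to an embedding table and initialized from an existing token (the token ids and sizes here are made up for illustration):

    import torch

    vocab_size, hidden = 100, 32
    emb = torch.nn.Embedding(vocab_size + 2, hidden)  # room for 2 new tokens
    DASH_ID = 7                   # pretend id of "---" in the tokenizer
    START_THOUGHT = vocab_size    # <|startofthought|>
    END_THOUGHT = vocab_size + 1  # <|endofthought|>

    # Start from the em-dash embedding, since "---" already signals a
    # pause; both rows are then optimized along with the rest of the model.
    with torch.no_grad():
        emb.weight[START_THOUGHT] = emb.weight[DASH_ID]
        emb.weight[END_THOUGHT] = emb.weight[DASH_ID]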
The next step involved a so-called “mixing head,” a “shallow” multilayer perceptron. This helped the researchers retroactively determine how much of the next-token prediction from a given thought should be incorporated into the current next-token prediction.
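A minimal sketch of what such a mixing head could look like (the layer sizes and the two-input design are assumptions for illustration):

    import torch
    import torch.nn.functional as F

    hidden, vocab = 32, 100

    # Shallow MLP that outputs, per position, how much of the post-thought
    # prediction to blend into the base prediction.
    mixing_head = torch.nn.Sequential(
        torch.nn.Linear(2 * hidden, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, 1),
    )

    h_base = torch.randn(1, hidden)     # hidden state without the thought
    h_thought = torch.randn(1, hidden)  # hidden state after the thought
    w = torch.sigmoid(mixing_head(torch.cat([h_base, h_thought], dim=-1)))

    logits_base = torch.randn(1, vocab)
    logits_thought = torch.randn(1, vocab)
    probs = (w * F.softmax(logits_thought, dim=-1)
             + (1 - w) * F.softmax(logits_base, dim=-1))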
Finally, the researchers optimized the parameters to increase the likelihood of the more probable future text. Reinforcement techniques provide a “learning signal” to rationales based on their effect on future predictions. To reduce variance, the researchers also introduced a “teacher forcing” trick, which ensures that the neural network stays as close as possible to the ground-truth sequences.
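Sketched in code, the teacher-forcing part of the objective amounts to a standard next-token cross-entropy against the true continuation, added to the REINFORCE term sketched earlier (the tensors below are illustrative):

    import torch
    import torch.nn.functional as F

    vocab = 100
    true_next_tokens = torch.tensor([12, 3, 57])   # ground-truth continuation
    mixed_logits = torch.randn(3, vocab, requires_grad=True)  # after mixing

    # Regardless of what the model would have sampled, the loss is computed
    # against the true next tokens, anchoring training to the ground truth.
    nll = F.cross_entropy(mixed_logits, true_next_tokens)
    nll.backward()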
Ultimately, “Quiet-STaR represents a step toward language models that can learn to reason in a general and scalable way,” the researchers conclude. “Future work can build on these insights to further close the gap between language models and human-like reasoning abilities.”