Comments (30)
- HarHarVeryFunny: It's interesting that it's the middle layers of the Transformer that are affected most by RL post-training, but it perhaps makes some intuitive sense given that RL is being used to shape the high-level, planning-type direction of the output. It seems that the input layers of a Transformer are necessarily going to be doing the most low-level work of syntax -> semantic augmentation, starting with things like tagging parts of speech etc. Similarly, the output layers are by necessity going to be concerned with mapping high-level representations back into surface-level word sequence form. This leaves the middle layers to do the work of first recognizing deep enough patterns to support good quality prediction, then doing the high-level prediction itself, which is what RL is typically going to be trying to shape.
- mike_hearn: This result feels very intuitive. The early layers of a transformer can be thought of as understanding surface-level things like syntax, how tokens group, which groups are entities and how to disambiguate them, etc. The last layers are in a sense decoding ideas into a selection of words, ensuring the grammar makes sense, that the text flows and is structured correctly, etc. The middle layers are where the abstract thought and manipulation of concepts is happening. But for the tasks this paper uses for RL training, it's all about improving the way the net is manipulating concepts. So the middle layers are where the focus should be. Note: RL is also used for tasks that aren't about conceptual manipulation, like instruct training. I bet that their result doesn't hold for that because the delta vs the foundation model is all about the selection of words and flow of the text, not the core understanding.
- hazrmard: Good work! I wonder if meta-learning can play a better role here compared to heuristics or hindsight. MAML requires Hessians, but first-order MAML or Reptile variants could help apply layer-wise adjustments to learning rates.
- olliepro: The authors have some inconsistencies with training token length… Most errors are probably responses that didn’t finish before their 3K token limit. They’ve measured how well RL is able to shorten the response to their limit.
- usernametaken29: If you think about it for some time then you’ll come to realise transformers are autoencoders on steroids. A small input space is expanded onto a big manifold and contracted again. Now, suppose you want to impose a function to regulate the output of an autoencoder. It’s actually pretty obvious that you need exactly one layer to do so… f(manifold).
- janalsncm: This is interesting theoretically, but in practical terms it’s hard to apply. RL is already hard. There are many things which can go wrong. You have all of the problems with regular LLM SFT, plus now you have a reward model which can be hacked or too hard. Or KL collapse because the outputs are repetitive. Or maybe your groups in GRPO aren’t producing advantages. Or the rollouts are OOD for your reward model. Or maybe you’re running the rollout at a different precision than the trained weights. Or maybe your importance sampling should be clipping when it’s not, or should be clipping at the token level rather than the sequence level. Maybe after reading the above you think that these are not problems because smart people wouldn’t make those mistakes. Fair enough. But I would prefer RL for people like myself who are not geniuses. Now, this is adding another variable into the mix: choosing a single layer to train. If it doesn’t work, is it because there’s a problem with your RL setup? Or did you just choose the wrong layer? Or maybe there’s no problem with your setup but you chose a suboptimal layer to train. Also note that we already have LoRA, which is a more established method for low-memory parameter updates.
- baq: I'm reminded of this dude who was sitting at or near the top of some Kaggle leaderboard by simply[0] splicing together some duplicated middle layers and applying a bit of fine-tuning. [0] not simply
- soleveloper: Makes sense. This is very similar to fine-tuning on a downstream task in an encoder-decoder architecture (~BERT style).
- tribal808: If most of the performance gains are hidden in a few middle layers, you can save a massive amount of compute by freezing the rest.
- khalic: Really good work here, bravo.
- vatsachak: I still can't believe that LLM encoders aren't learned unsupervised. So much left on the table.
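
One of the failure modes janalsncm lists, GRPO groups that produce no advantages, can be illustrated with a small sketch. It assumes the commonly described group-relative normalization (reward minus the group mean, divided by the group standard deviation). The function name `group_advantages` and the `eps` value are illustrative and not taken from the paper or from any library.

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-6):
    """Normalize rewards within one group of rollouts for the same prompt.

    Assumed form: (r - group mean) / (group std + eps).
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# A group with mixed outcomes yields positive and negative advantages.
mixed = group_advantages([1.0, 0.0, 0.0, 1.0])

# A group where every rollout gets the same reward yields all zeros,
# so that prompt contributes no learning signal.
uniform = group_advantages([1.0, 1.0, 1.0, 1.0])
```

If most prompts are either always solved or never solved, most groups look like `uniform`, which is one way a run can appear healthy while learning nothing.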
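
janalsncm's distinction between clipping importance ratios at the token level and at the sequence level can be sketched as follows. This is a simplified illustration: it shows only how the ratio is formed and clipped, and leaves out the advantage term and the min with the unclipped objective used in PPO-style losses. The function names, the `eps=0.2` default, and the length-normalized form of the sequence ratio are assumptions made for the example.

```python
import math

def clip(x, lo, hi):
    return max(lo, min(hi, x))

def token_level_ratios(logp_new, logp_old, eps=0.2):
    """One importance ratio per token, each clipped separately."""
    return [clip(math.exp(n - o), 1 - eps, 1 + eps)
            for n, o in zip(logp_new, logp_old)]

def sequence_level_ratio(logp_new, logp_old, eps=0.2):
    """One ratio for the whole response, clipped once.

    Assumed form: exp of the mean per-token log-prob difference.
    """
    diffs = [n - o for n, o in zip(logp_new, logp_old)]
    return clip(math.exp(sum(diffs) / len(diffs)), 1 - eps, 1 + eps)

# Hypothetical per-token log-probs for a three-token response.
logp_old = [-1.0, -1.0, -1.0]
logp_new = [-0.5, -1.0, -1.5]

per_token = token_level_ratios(logp_new, logp_old)   # first and last tokens hit the clip bounds
per_sequence = sequence_level_ratio(logp_new, logp_old)  # differences cancel, nothing is clipped
```

The same pair of policies is clipped on two of three tokens in one scheme and not at all in the other, which is why the choice is hard to debug from the loss curve alone.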
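
tribal808's point about freezing everything except a few middle layers can be made concrete with a rough count. The sketch below assumes a hypothetical 24-layer stack with equal-sized blocks and ignores embeddings and the output head. The helper names are made up for the example.

```python
def trainable_fraction(params_per_layer, trainable_layers):
    """Fraction of parameters that receive gradients and optimizer state."""
    total = sum(params_per_layer)
    trainable = sum(params_per_layer[i] for i in trainable_layers)
    return trainable / total

def layers_needing_backward(n_layers, trainable_layers):
    """Layers the backward pass must traverse.

    Gradients flow from the loss down to the earliest trainable layer,
    so every layer at or above it still needs a backward pass for
    activations, even if its own weights are frozen.
    """
    return list(range(min(trainable_layers), n_layers))

# Hypothetical model: 24 blocks of 1M parameters each, training blocks 11 and 12.
layers = [1_000_000] * 24
frac = trainable_fraction(layers, trainable_layers=[11, 12])
backward = layers_needing_backward(len(layers), trainable_layers=[11, 12])
```

Under these assumptions only 2 of 24 blocks need weight gradients and optimizer state, but the backward pass still runs through 13 blocks, so the compute saving is smaller than the parameter count alone suggests.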