Not sure if this is useful, but if you know the model's initial output biases, you can recalibrate them yourself, provided you have output probabilities for all tokens (or at least all non-negligible ones).
Say the model outputs probabilities over n tokens, the prior (bias) baked into the model is m = (m_1, m_2, ..., m_n), and the new prior you want is q = (q_1, q_2, ..., q_n). Then if the model outputs prediction p = (p_1, ..., p_n) over all tokens, the recalibrated output you are looking for is

bias_shift(p) = softmax(logit(p) + log(q) - log(m))
You can prove this using Bayes' rule plus a bit of algebra. Most ML people don't seem to know this trick, but it's super useful for domain adaptation / class upsampling, where the class balance in your training set is different from the one you want to predict on.
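A minimal numpy sketch of the above (the name bias_shift is from the formula; applying the correction to log(p) rather than to the model's raw pre-softmax logits is my choice - softmax is shift-invariant, so the two are equivalent):

    import numpy as np

    def bias_shift(p, m, q, eps=1e-12):
        # Reweight each class probability by q_i / m_i, then renormalize.
        # Equivalent to softmax(log(p) + log(q) - log(m)).
        p, m, q = np.asarray(p, float), np.asarray(m, float), np.asarray(q, float)
        shifted = np.log(p + eps) + np.log(q + eps) - np.log(m + eps)
        shifted -= shifted.max()              # numerical stability before exp
        out = np.exp(shifted)
        return out / out.sum()

    # e.g. a 3-class model trained on a 90/5/5 class balance, deployed where
    # the true balance is uniform (numbers are illustrative):
    p = [0.6, 0.3, 0.1]
    m = [0.9, 0.05, 0.05]
    q = [1/3, 1/3, 1/3]
    print(bias_shift(p, m, q))   # upweights the classes that were rare in training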
We take logits to start, so at the end we need to map back from logit space to probability space. A linear shift in logit space doesn't preserve the sum-to-one requirement of the corresponding probabilities, so you can't just take the sigmoid of each element individually - that's what the softmax (which renormalizes) is for. This is just a slick way of writing it; you can write it all out without any logits / sigmoids, but it's easier to do the algebra in logit space.
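To make that normalization point concrete, a quick check (the values are made up; assumes numpy and the bias_shift sketch above):

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def elementwise_logit(x):
        return np.log(x) - np.log1p(-x)       # per-element log-odds

    p = np.array([0.6, 0.3, 0.1])
    shift = np.log([1/3, 1/3, 1/3]) - np.log([0.9, 0.05, 0.05])

    # Elementwise sigmoid of the shifted logits does NOT sum to 1:
    print(sigmoid(elementwise_logit(p) + shift).sum())

    # The softmax version renormalizes, so it does:
    print(bias_shift(p, [0.9, 0.05, 0.05], [1/3, 1/3, 1/3]).sum())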
Ah ok, I'm not overly familiar with the problem they're trying to solve - it wouldn't always be necessary, depending on your application. But I was talking in broad ML terms: if you need a calibrated prediction at the end, you do need this.
They are related, but not the same. This recalibrating of logits is basically what you would do afterwards, to correct for the fact that you have been importance sampling your replay buffer.
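Concretely (a sketch reusing the bias_shift helper above; the frequencies below are made up, not from any real replay setup): m is the class balance your importance-sampled buffer actually exposed the model to, and q is the balance of the real data stream.

    # Assumes numpy and the bias_shift helper sketched above are in scope.
    buffer_freqs = np.array([0.25, 0.25, 0.5])   # balance the model trained on
    stream_freqs = np.array([0.7, 0.2, 0.1])     # balance it will be scored on

    model_probs = np.array([0.2, 0.5, 0.3])      # hypothetical raw prediction
    calibrated = bias_shift(model_probs, m=buffer_freqs, q=stream_freqs)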