New kind of recurrent neural network using attention (github.com)
155 points by jostmey on Mar 8, 2017 | 42 comments

Author here. I'd be happy to address any questions. Constructive criticism is always welcome.

Sounds promising! A couple of questions after a quick read, though I'm sure I'll have more later: Is it faster than an LSTM for training / inference? And are you planning to use it for more complex problems such as translation or image captioning?

The RWA model requires fewer training steps than an LSTM model on the tasks I tested (https://arxiv.org/pdf/1703.01253.pdf). Even better, the RWA model uses over 25% fewer parameters (because it doesn't have lots of gates like an LSTM model). So the RWA uses less computer memory than an LSTM model.

To answer your second question: Yes I do plan on testing it on more challenging problems! I just need enough time and the GPU rig.

I was wondering about the time per step; I am aware it converges faster.

Let me know when you have more results! I want to start doing things with text soon and I would love to try RWA

The RWA model is described in equations (7) of this paper: https://arxiv.org/pdf/1703.01253.pdf

I think you will see that the cost per training step is just as cheap as, if not cheaper than, an LSTM model. On my machine, the RWA model completes each training step faster than an LSTM model. I don't know if that will hold for GPUs.

Yes, give us wall time please

So in that case it has the same number of parameters as a GRU. Did you compare with that?

No, I have not tried the GRU model

Let me make sure I understand: rather than using gates, each step explicitly takes a weighted average of all past steps using attention.

What is the training speed of this network? Computation seems to scale at N^2 rather than N.

EDIT: The attention mechanism is nothing more than a weighted average. The weighted average is computed as a running average by saving the numerator and denominator terms at each step.
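
A minimal sketch of the running-average idea described in the edit above. It assumes fixed scalar values `zs` and positive weights `ws` (both made up for illustration); in the actual model these would depend on the input and the previous hidden state, so they would be computed step by step. The point is only that saving the numerator and denominator gives the same average as recomputing it from scratch, at constant cost per step.

```python
def weighted_average_direct(zs, ws):
    """Recompute the weighted average over all steps seen so far."""
    return sum(z * w for z, w in zip(zs, ws)) / sum(ws)


def weighted_average_running(zs, ws):
    """Return the same averages, carrying only a numerator and a denominator."""
    n, d = 0.0, 0.0
    averages = []
    for z, w in zip(zs, ws):
        n += z * w  # running numerator
        d += w      # running denominator
        averages.append(n / d)
    return averages


# Hypothetical values, chosen only for illustration.
zs = [0.5, -1.0, 2.0, 0.25]
ws = [1.1, 4.5, 0.7, 2.0]

running = weighted_average_running(zs, ws)
direct = [weighted_average_direct(zs[:t], ws[:t]) for t in range(1, len(zs) + 1)]
```

The direct version does O(t) work at step t (O(N^2) over the sequence), while the running version does O(1) work per step.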

I hope the description in the paper is clear. You can follow the ARXIV link in the README. Skip straight to section 2 for the details of the model

The stuff in Appendix C worries me a little. Dividing the numerator and denominator indeed does not change the output at timestep t. However, if you use the new, smaller numerator and denominator at later timesteps, later outputs will be affected.

Suppose, for example, that you divide the numerator and denominator by 2, giving n/2 and d/2. In the next timestep, suppose a<0. You now do n=n/2+za and d=d/2+a. What you really ought to be doing is (n+za)/2 and (d+a)/2. In other words, once you scale n and d by some factor, all later timesteps should have that same factor thrown in (in order to preserve the outputs from the theoretical model).

If you do things the way I suspect, then any positive value of `a` essentially gives more priority to later timesteps (since their numerators and denominators are not scaled down by the theoretically necessary amount). This seems to defeat most of the purpose of the model.

Edit: by the way, I tried looking through your code to figure out what it was actually doing, and it does seem like you are doing what I thought. If I am correct, then your numerical stability code does actually affect the model's output.
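
A small worked example of the concern raised in this comment, using made-up numbers rather than anything taken from the repository. It compares three ways of handling a rescale by 2 between two steps: no rescale (the theoretical model), rescaling only the saved sums, and rescaling the saved sums and the new terms by the same factor.

```python
def step(n, d, z, w):
    """Add one term with value z and weight w to the running sums."""
    return n + z * w, d + w


# First step (hypothetical values): z = 1, weight = 4.
n, d = step(0.0, 0.0, z=1.0, w=4.0)

# Theoretical model: no rescaling, second step with z = -1, weight = 2.
n_exact, d_exact = step(n, d, z=-1.0, w=2.0)
exact = n_exact / d_exact

# Suspected behavior: halve the saved sums, then add the new term unscaled.
n_bug, d_bug = step(n / 2, d / 2, z=-1.0, w=2.0)
buggy = n_bug / d_bug

# Consistent rescaling: the new term carries the same factor of 1/2.
n_fix, d_fix = step(n / 2, d / 2, z=-1.0, w=2.0 / 2)
fixed = n_fix / d_fix
```

Here `exact` and `fixed` agree (both are 1/3), while `buggy` comes out as 0 because the later, unscaled term is over-weighted.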

Oh no! I think you may be right. It is possible to correct the model for numerical stability in a manner that doesn't impact the output, but my approach fails at that. The correction from the previous step needs to be saved and carried over to the next step, just like the numerator and denominator are saved and carried over.
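
One common way to carry the correction forward, sketched here as an assumption rather than as the fix that ended up in the repository: keep a running maximum `a_max` of the exponents alongside the numerator and denominator, and store both sums scaled by `exp(-a_max)`. Whenever the maximum grows, the saved sums are rescaled by the same factor, so the ratio n/d is unchanged. Scalar `zs` and `as_` are used for illustration.

```python
import math


def stable_running_average(zs, as_):
    """Running average with weights exp(a), stored relative to a running max."""
    n, d, a_max = 0.0, 0.0, -math.inf
    averages = []
    for z, a in zip(zs, as_):
        a_new = max(a_max, a)
        # Rescale the saved sums to the new reference exponent.
        # On the first step exp(-inf) is 0.0, which leaves the empty sums at 0.
        scale = math.exp(a_max - a_new)
        w = math.exp(a - a_new)  # never larger than 1, so it cannot overflow
        n = n * scale + z * w
        d = d * scale + w
        a_max = a_new
        averages.append(n / d)
    return averages


# Exponents this large would overflow math.exp if used directly.
zs = [1.0, 2.0, 3.0]
as_ = [1000.0, 1001.0, 999.0]
averages = stable_running_average(zs, as_)
```

Because the new term and the saved sums always share the same reference exponent, the outputs match the unscaled model whenever the unscaled model can be computed at all.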

I think I will have to re-run all the results once I fix the code. This is a bummer

-- Thanks for finding this! Jared

Does this scale to longer sequences? It looks like it needs to have O(n^2) weights for sequences of n entries. Or am I misunderstanding?

My team would like to do training on long text documents to do analysis on them, but they can be thousands of words long.

In my opinion, that's the beauty of the approach. The attention mechanism reuses the same weights for each symbol in the sequence. The attention mechanism is essentially a weighted average that filters which pieces of the sequence are fed through each processing step. So no, it does not need O(n^2) weights.

My guess is that the RWA model would be well suited for natural language processing. I am biased. Of course, you would need to first convert each word into a vector representation (e.g. word2vec). You might also want to add a decay term in the weighted average to force the model to focus on the recent past and not the "deep" past.
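
The comment does not say what form the decay term would take. One possible reading, offered purely as an assumption, is to multiply the saved numerator and denominator by a hypothetical factor `decay` between 0 and 1 before adding each new term, so older steps count for less:

```python
def running_average(zs, ws, decay=1.0):
    """Weighted running average; decay < 1 discounts older steps."""
    n, d = 0.0, 0.0
    for z, w in zip(zs, ws):
        n = decay * n + z * w
        d = decay * d + w
    return n / d


# Hypothetical values: two early steps with z = 1, then a recent step with z = 0.
zs = [1.0, 1.0, 0.0]
ws = [1.0, 1.0, 1.0]

plain = running_average(zs, ws)               # every step counts equally
decayed = running_average(zs, ws, decay=0.5)  # recent steps count more
```

With these numbers `plain` is 2/3 and `decayed` is 3/7, so the decayed average sits closer to the most recent value.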

It still scales O(N^2) in terms of computation though right?

If you save numerator_t and denominator_t from a previous computation, then h_t = f( (numerator_(t-1) + z + a) / (denominator_(t-1) + a) ). So it should be linear with t?

(where z and a are z(x_i, h_(i-1)), etc)

Yes, I do save the numerator and denominator.

The model I used is: h_t = f( (numerator_(t-1) + z * e^a) / (denominator_(t-1) + e^a) )

Equations (7) in the arxiv link (here it is again: https://arxiv.org/pdf/1703.01253.pdf) provide the update equations. When implementing the equations, you sometimes have to scale both the numerator and denominator back by a constant factor to avoid an overflow error

I don't follow the intuition behind using an exponentiation of a as in e^a. You refer to this as a "context model" in your paper. Could you please elaborate? Thanks!

The name "concept model" is borrowed from one of the papers I cited. The concept model is just "a(x_i,h_{i-1})". It decides what is important. You can sort of think of it like a filter or gate. When the concept model returns a large value, creating a large exponent, the information encoded by "z(x_i,h_{i-1})" dominates the weighted average.
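
A tiny numeric illustration of that last point, with made-up scalar values: when one step's `a` is much larger than the others, its exponential weight takes almost the entire share of the average, so its `z` dominates.

```python
import math

# Hypothetical values: the third step has a much larger exponent.
zs = [0.2, -0.7, 0.9]
as_ = [0.0, 0.0, 8.0]

weights = [math.exp(a) for a in as_]
total = sum(weights)
shares = [w / total for w in weights]  # fraction of the average each step owns
average = sum(z * s for z, s in zip(zs, shares))
```

Here the third step's share is above 99%, and `average` lands very close to 0.9.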

Could you include a little more describing the baselines you sought to beat in the readme? "Tasks" and "LSTM network" sounds vague to me.

I will try to update the README. The results are described in the arXiv link. I hope the paper is clear. Here's the link: https://arxiv.org/pdf/1703.01253.pdf

Is RWA appropriate for tasks other than classification?

The method I presented is for classification. It is not a generative model, although there's no reason why it can't be incorporated into a generative model.

The RWA model could be used to perform sequence to sequence mappings like in natural language translation. It could also be used to take a seed value and generate a sequence, like automated composition of music. These things have not yet been tried.

Doesn't the memory impact of storing the hidden states scale with O(gamma function)?

What do you mean by O(gamma function)?

The only gamma function I can think of is https://en.wikipedia.org/wiki/Gamma_function , which doesn't make much sense in context.

Perhaps he meant to say O(N!), which would be horrendous scaling.

My bad, it scales as O(n). I figured it out lol

How do GRU cells do on these problems?

I have not tested GRU cells on these problems. The GRU cells are not too different from LSTM cells--both models make heavy use of gating terms. As such, I don't expect a GRU model to do much better than an LSTM. But I cannot say for certain, having not tried it.

Link to the actual paper: https://arxiv.org/pdf/1703.01253.pdf


link to the abstract

Cool, I feel like the days of the LSTM as the de facto strong baseline are ending as these new sequence models come out.

How is this any different from this work (see figure 1)?


I've only skimmed it, so what I am about to say could be incorrect. The big difference between the models is how the attention mechanism is computed. The attention mechanism is just a weighted average. The weighted average can be computed as a running (or moving) average by saving the numerator and denominator at each step. This is what the RWA model is doing. I think the LSTM-N model, referenced in the paper you give, tries to brute force the calculation of the attention mechanism. Again, I could be wrong. I've only skimmed the paper you speak of

I'm unclear about why z needs to be parameterized in terms of two functions, one unbounded, u(x), and one bounded, tanh(g(x,h)). Shouldn't the squashing function f() prevent explosion anyway? And how would this parameterization prevent vanishing gradients?

Good question!

In retrospect, I should have explained that section differently. I will probably revise that section of the text (that's what I love about arxiv).

We don't want the recurrent terms "h_{t-1}" to dominate the output of "z". But we don't want to eliminate the recurrent terms, because without "h_{t-1}" the output of "z" will not be sequence dependent. So the "tanh" term is included to allow the recurrent terms to change the sign of "z".
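
A scalar sketch of that explanation, assuming z is the product u(x) * tanh(g(x, h)) with simple linear forms for `u` and `g`. The product form and all coefficients below are illustrative assumptions, not values from the paper or the code. Since tanh stays between -1 and 1, the recurrent term can flip the sign of z but cannot push its magnitude beyond |u(x)|.

```python
import math


def u(x, w_u=2.0, b_u=0.0):
    """Unbounded term; depends on the input only (hypothetical coefficients)."""
    return w_u * x + b_u


def g(x, h, w_x=0.5, w_h=3.0, b_g=0.0):
    """Term fed to tanh; depends on the input and the previous hidden state."""
    return w_x * x + w_h * h + b_g


def z(x, h):
    """Assumed combination: unbounded part times a bounded, sign-changing part."""
    return u(x) * math.tanh(g(x, h))


x = 1.5
z_pos = z(x, h=1.0)   # previous state pushes tanh toward +1
z_neg = z(x, h=-1.0)  # previous state pushes tanh toward -1
```

The same input `x` yields a positive or a negative z depending on the previous state, while |z| never exceeds |u(x)|.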

I don't know if that helps.


Ah, now it's much more clear. By the way, have you tested the model without parameterizing z the way you did? Different z's might be worth further exploration.

Yes, different models for "z" are worth exploring! You should try different forms and see what works!

Surely the potential issue with this compared to a LSTM+ model is that the average function will have insufficient capacity to capture even comparatively common exceptions in a sequence?

pretty cool. i'm going to try char-rnn with rwa some time
