We're now going to apply Hessian-free optimization to the task of modeling character strings from Wikipedia. So, the idea is, you read a lot of Wikipedia and then try to predict the next character. Before we get to see what the model learns, I want to describe why we need multiplicative connections and how we can implement those multiplicative connections efficiently in a recurrent neural network. I need to start by explaining why we chose to model character strings rather than strings of words, which is what you normally do when you're trying to model language.

The web is composed of character strings. Any learning method that's powerful enough to understand what's going on in the world by reading the web ought to find it trivial to learn which strings make words. As we'll see, this turns out to be true. So we're going to be very ambitious here. We want something that will read Wikipedia and understand the world.

If we have to pre-process the text in Wikipedia into words, it's going to be a big hassle. There are all sorts of problems. The first problem is morphemes. The smallest units of meaning, according to linguists, are morphemes. So we're going to have to break up a word into these morphemes if we want to deal with it sensibly. The problem is, it's not quite clear what morphemes are. There are things that are a bit like morphemes, but that a linguist wouldn't call a morpheme. So in English, if you take any word that starts with the letters "sn", it has a very high chance of meaning something to do with the lips or nose, particularly the upper lip or nose. So words like snarl, and sneeze, and snot, and snog, and snort. There are too many of these words for it to be just coincidence. Many people say, yes, but what about snow? That's got nothing to do with the upper lip or nose. But ask yourself something: why is snow such a good word for cocaine?

Then there are words that come in several pieces. So normally, we'd want to treat New York as one lexical item. But if we're talking about the New Yorkminster roof, then we might want to treat New and York as two separate lexical items. And then there are languages like Finnish. Finnish is an agglutinative language, so it puts together lots of morphemes to make great big words. So here's an example of a word in Finnish that takes about five words in English to say the same thing. I have no idea what this word means.
But despite my lack of understanding, it makes the point.

So here's an obvious kind of recurrent net we might use to try and model character strings. It has a hidden state, and in this case we're going to use 1,500 hidden units. The hidden state dynamics are that the hidden state at time t provides input to determine the hidden state at time t+1. The character also provides some input. So we add together the effect of the current character and the previous hidden state to get the new hidden state. And then when we arrive at a new hidden state, we try and predict the next character. So we have a single softmax over the 86 characters, and we get the hidden state to try and assign high probability to the correct next character, and low probability to the others. And we train the whole system by backpropagating from that softmax the log probability of getting the correct character. We backpropagate that through the hidden-to-output connections, back through the character-to-hidden connections, and then back through the hidden-to-hidden connections, and so on, all the way back to the beginning of the string.

It's a lot easier to predict 86 characters than 100,000 words. So it's easier to use a softmax at the output; we don't have the problem of a great big softmax.

Now, to explain why we didn't use that kind of recurrent net, but instead used a different kind of net that worked quite a lot better: you could arrange all possible character strings into a tree, with a branching ratio of 86 in our case. And what I'm showing here is a tiny little subtree of that great big tree. In fact, this little subtree will occur many times, but with different things represented by that dot, dot, dot before the "fix". So this represents that we had a whole bunch of characters, then we had f, and then i, and then x. And now if we get an i, we're going to go to the left. If we get an e, we're going to go to the right, and so on. So each time we get a character, we move one step down in this tree to a new node. There are exponentially many nodes in the tree of all character strings of length N, so this is going to be a very big tree. We couldn't possibly store it all. If we could store it all, what we'd like to do is put a probability on each of those arrows. And that would be the probability of producing that letter, given the context of the node.
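To make that architecture concrete, here is a minimal NumPy sketch of the additive character-level RNN just described, with 1,500 hidden units and a softmax over the 86 characters. The weight names, the tanh nonlinearity, the random initialization, and the absence of bias terms are illustrative assumptions, not details taken from the lecture.

```python
# A minimal sketch of the plain additive character-level RNN: the new hidden
# state combines the previous hidden state with the current character's
# additive input, and a single softmax predicts the next character.
import numpy as np

n_hidden, n_chars = 1500, 86
rng = np.random.default_rng(0)
W_hh = rng.normal(0.0, 0.01, (n_hidden, n_hidden))  # hidden-to-hidden weights
W_ch = rng.normal(0.0, 0.01, (n_chars, n_hidden))   # character-to-hidden weights
W_ho = rng.normal(0.0, 0.01, (n_hidden, n_chars))   # hidden-to-output weights

def step(h_prev, char_id):
    """One time step: add the current character's effect to the previous
    hidden state, then predict a distribution over the 86 next characters."""
    h = np.tanh(h_prev @ W_hh + W_ch[char_id])
    logits = h @ W_ho
    exp = np.exp(logits - logits.max())
    return h, exp / exp.sum()
```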
In an RNN, we try and deal with the fact that the full tree is enormous by using a hidden state vector to represent each of these nodes. So now what the next character has to do is take the hidden state vector that's representing the whole string of characters followed by "fix", and operate on that hidden state vector to produce the appropriate new hidden state vector if the next character was an i. So when you see an i, you want to turn the hidden state vector into a new hidden state vector.

A nice thing about implementing these nodes in the character tree by using the hidden state of a recurrent neural network is that we can share a lot of structure. For example, by the time we arrive at that node that says f, i, x, we may have decided that it's probably a verb. And if it's a verb, then i is quite likely, because of the ending i, n, g. And that knowledge, that i is quite likely after a verb, can be shared with lots of other nodes that don't have f, i, x in them. So we can get i to operate on the part of the state that represents that it's a verb, and that can be shared between all the verbs.

Notice that it's really the conjunction of the current state we're at and the character that determines where we want to go. We don't want i to give us a state that's expecting to get an n next if it wasn't a verb. So we don't want to say that i tends to make you expect an n next. We really want to say: if you already think it's a verb, then when you see an i, you should expect an n next. It's the conjunction of the fact that we think it's a verb, and that we saw an i, that gets us into this state labeled f, i, x, i, that's expecting to see an n.

So we're going to try and capture that by using multiplicative connections. Instead of using the character inputs to the recurrent net to give extra additive input to the hidden units, we're going to use those characters to swap in a whole hidden-to-hidden weight matrix. The character is going to determine the transition matrix. Now, if we did that in the naive way, we'd have each of the 86 characters define a 1500 x 1500 matrix, and that would be a lot of parameters. If we have that many parameters, then we're likely to overfit, unless we run it on a huge amount of text, for which we might not have time.
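To see why the naive version is so expensive, here is a sketch of what it would mean for each character simply to index its own full hidden-to-hidden matrix. The names and toy dimensions are assumptions for illustration; the point is the parameter count.

```python
# Naive multiplicative interaction: a separate full transition matrix per
# character. This is the scheme rejected above because of its size.
import numpy as np

n_hidden, n_chars = 1500, 86
print(n_chars * n_hidden * n_hidden)  # 193,500,000 parameters in the naive scheme

# Toy-sized version of the idea, just to show the indexing:
h_dim, c_dim = 50, 86
W_per_char = np.random.default_rng(0).normal(0.0, 0.01, (c_dim, h_dim, h_dim))

def naive_step(h_prev, char_id):
    # The current character swaps in its own full hidden-to-hidden matrix.
    return np.tanh(W_per_char[char_id] @ h_prev)
```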
So the question is: can we achieve this kind of multiplicative interaction, where the character determines the hidden-to-hidden weight matrix, using many fewer parameters, by making use of the fact that characters have things in common? For example, all of the digits are quite similar to each other in the way in which they make the hidden state evolve. So we want to have a different transition matrix for each of those 86 characters, but we want those 86 character-specific weight matrices to share parameters, and that's a reasonable thing to do because we know that characters like 8 and 9 should have very similar transition matrices.

So here's how we're going to do it. We're going to have things called factors, and they're going to be denoted by this little triangle with an f above it. What that factor means is that group a and group b interact multiplicatively to provide input to group c.

So, what each factor does is it first computes a weighted sum for each of its two input groups. We take the vector state of group a, which I'll just call a, and we multiply it by the weights on the connections coming into the factor. In other words, we take the scalar product of the vector a and the weight vector u, and that gives us a number at the left-hand vertex of that triangle. Similarly, we take the vector state of group b and we multiply it by the weight vector w, and we get another number at the bottom vertex of the triangle. We now multiply those two numbers together, and that gives us a number, or scalar. And we use that scalar to scale the outgoing weights v in order to provide input to group c. So the input to group c is just the product of the two numbers that arrive at the two input vertices of the triangle, times the outgoing weight vector v.

We can write that as an equation. The input that factor f provides to group c, which is a vector, is the scalar input to f from group b (got by multiplying the state of group b by the weights w_f), times the scalar input to f from group a (got by multiplying the state of group a by the weights u_f). We take the product of those two scalars and multiply the weight vector v_f by that, and that's the input the factor gives to group c. Then, of course, we're going to have a whole bunch of those factors.

There's another way we can think about these factors that gives more insight into what's going on.
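In symbols, the input that factor f sends to group c is c_f = (b . w_f)(a . u_f) v_f. Here is a small NumPy sketch of a whole bank of such factors computed at once; the shapes and variable names are assumptions for illustration, not the lecture's code.

```python
# A bank of factors, vectorized: column f of U holds u_f, column f of W holds
# w_f, and row f of V holds v_f. Each factor contributes (a.u_f)*(b.w_f)*v_f
# to group c, and the contributions are summed over factors.
import numpy as np

n_a, n_b, n_c, n_factors = 1500, 86, 1500, 1500
rng = np.random.default_rng(0)
U = rng.normal(0.0, 0.01, (n_a, n_factors))   # group a -> factor weights u_f
W = rng.normal(0.0, 0.01, (n_b, n_factors))   # group b -> factor weights w_f
V = rng.normal(0.0, 0.01, (n_factors, n_c))   # factor -> group c weights v_f

def input_to_c(a, b):
    gains = (a @ U) * (b @ W)   # one scalar (a.u_f)(b.w_f) per factor
    return gains @ V            # sum over factors of that scalar times v_f
```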
Each of the factors actually defines a very simple kind of transition matrix: a transition matrix that has rank one. The equation we had on the previous slide treats a factor as computing two scalar products, multiplying them together, and then using that as a weight on the outgoing vector v. We can rearrange that equation so that we get one scalar product, and then we rearrange the last bit so that now we take the outer product of the weight vector u and the weight vector v, and that gives us a matrix. And the scalar product that would be computed by multiplying b by w is just a coefficient on that matrix. So we get a scalar coefficient, we multiply a rank-one matrix by that scalar coefficient to give us a scaled matrix, and then we multiply the current hidden state a by this matrix to determine the input that factor f gives to the next hidden state.

If we sum that up over all the factors, the total input to group c is just a sum over all factors of a scalar times a rank-one matrix, and that sum is a great big matrix, that's the transition matrix, and it gets multiplied by the current hidden state to produce the new hidden state. So we can see that we've synthesized the transition matrix out of these rank-one matrices provided by each factor. And what the current character in group b has done is determine the weight on each of these rank-one matrices. So b times w determines the scalar weight, the scalar coefficient to put on each of the matrices, out of which we're going to compose this great big character-specific weight matrix.

So here's a picture of the whole system. We have a number of factors; in fact, we'll have about 1,500 factors. And the character input is different in that only one of the characters is active, so there will only be one relevant weight at a time. And that weight from the current character k, which is called w_kf, is the gain that's used on the rank-one matrix got by taking the outer product of u and v. So the character determines a gain w_kf, you multiply the rank-one matrix uv by that gain, you add together the scaled matrices for all the different factors, and that's your transition matrix.
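To connect the two views, here is a small numerical check, with toy sizes and made-up weights, that summing the gain-scaled rank-one matrices u_f v_f^T gives exactly the same result as the factor-by-factor computation, where the gain on factor f is the weight w_kf picked out by the one-hot character vector. Everything here is an illustrative assumption except the structure of the computation itself.

```python
# Two equivalent views of the multiplicative connections, with toy dimensions:
# (1) factor view:   ((a @ U) * (b @ W)) @ V
# (2) matrix view:   W_k = sum_f w_kf * outer(u_f, v_f), then a @ W_k
import numpy as np

n_h, n_chars, n_factors = 100, 86, 80         # toy sizes (the lecture uses ~1500)
rng = np.random.default_rng(0)
U = rng.normal(0.0, 0.1, (n_h, n_factors))
W = rng.normal(0.0, 0.1, (n_chars, n_factors))
V = rng.normal(0.0, 0.1, (n_factors, n_h))

a = rng.normal(size=n_h)                      # current hidden state
b = np.zeros(n_chars); b[5] = 1.0             # one-hot current character k = 5

factor_view = ((a @ U) * (b @ W)) @ V
W_k = sum((b @ W)[f] * np.outer(U[:, f], V[f]) for f in range(n_factors))
matrix_view = a @ W_k                         # character-specific transition matrix
print(np.allclose(factor_view, matrix_view))  # True
```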