In this video, I'll describe the first way we discovered for getting Sigmoid Belief Nets to learn efficiently. It's called the wake-sleep algorithm, and it should not be confused with Boltzmann machines. They have two phases, a positive and a negative phase, that could plausibly be related to wake and sleep. But the wake-sleep algorithm is a very different kind of learning, mainly because it's for directed graphical models like Sigmoid Belief Nets, rather than for undirected graphical models like Boltzmann machines.

The ideas behind the wake-sleep algorithm led to a whole new area of machine learning called variational learning, which didn't take off until the late 1990s, despite early examples like the wake-sleep algorithm, and is now one of the main ways of learning complicated graphical models in machine learning.

The basic idea behind these variational methods sounds crazy. The idea is that since it's hard to compute the correct posterior distribution, we'll compute some cheap approximation to it. And then we'll do maximum likelihood learning anyway. That is, we'll apply the learning rule that would be correct if we'd got a sample from the true posterior, and hope that it works even though we haven't.

Now, you could reasonably expect this to be a disaster, but actually the learning comes to your rescue. If you look at what's driving the weights during the learning, when you use an approximate posterior, there are actually two terms driving the weights. One term is driving them to get a better model of the data, that is, to make the Sigmoid Belief Net more likely to generate the observed data in the training set. But there's another term added to that, which is actually driving the weights towards sets of weights for which the approximate posterior it's using is a good fit to the real posterior. It does this by manipulating the real posterior to try to make it fit the approximate posterior. It's because of this effect that variational learning of these models works quite nicely.
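For reference, those two terms come from the standard variational decomposition of the log probability of a data vector v under an approximate posterior q(h|v). The identity isn't written out in the video, so take this as a sketch of the underlying algebra:

```latex
% log-likelihood = variational bound + non-negative KL gap, for any approximation q(h|v)
\log p(v)
  = \underbrace{\sum_{h} q(h\mid v)\,\bigl[\log p(v,h) - \log q(h\mid v)\bigr]}_{\text{variational bound}}
  + \underbrace{\mathrm{KL}\bigl(q(h\mid v)\,\|\,p(h\mid v)\bigr)}_{\ge\,0}
```

Because the KL term can't be negative, pushing up the bound with respect to the generative weights either raises the log probability of the data or shrinks the divergence between the true posterior and the approximation, which is exactly the second effect described above.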
Back in the mid-90s, when we first came up with it, we thought this was an interesting new theory of how the brain might learn. That idea has been taken up since by Karl Friston, who strongly believes this is what's going on in real neural learning.

So, we're now going to look in more detail at how we can use an approximation to the posterior distribution for learning. To summarize, it's hard to learn complicated models like Sigmoid Belief Nets because it's hard to get samples from the true posterior distribution over hidden configurations when given a data vector. It's hard even to get an unbiased sample from that posterior. So, the crazy idea is that we're going to use samples from some other distribution and hope that the learning will still work. And as we'll see, that turns out to be true for Sigmoid Belief Nets.

The distribution that we're going to use is a distribution that ignores explaining away. We're going to assume (wrongly) that the posterior over hidden configurations factorizes into a product of distributions for each separate hidden unit. In other words, we're going to assume that, given the data, the units in each hidden layer are independent of one another, as they are in a Restricted Boltzmann machine. But in a Restricted Boltzmann machine this is correct, whereas in a Sigmoid Belief Net it's wrong.

So, let's quickly look at what a factorial distribution is. In a factorial distribution, the probability of a whole vector is just the product of the probabilities of its individual terms. So, suppose we have three hidden units in a layer and they have probabilities of being on of 0.3, 0.6, and 0.8. If we want to compute the probability of the hidden layer having the state (1, 0, 1), we compute that by multiplying 0.3 by (1 - 0.6) by 0.8. So, the probability of a configuration of the hidden layer is just the product of the individual probabilities. That's why it's called factorial.

In general, the distribution over binary vectors of length n will have two to the n degrees of freedom. Actually, it's only two to the n minus one, because the probabilities must add to one. A factorial distribution, by contrast, only has n degrees of freedom. It's a much simpler beast.
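To make that calculation concrete, here is a minimal sketch in Python. The numbers 0.3, 0.6, 0.8 and the state (1, 0, 1) are the ones from the example above; the function name is just for illustration.

```python
# Probability of a binary state vector under a factorial distribution:
# each unit is on independently with its own probability p_i.
def factorial_prob(on_probs, state):
    prob = 1.0
    for p, s in zip(on_probs, state):
        prob *= p if s == 1 else (1.0 - p)
    return prob

# Three hidden units with probabilities of being on of 0.3, 0.6 and 0.8.
print(factorial_prob([0.3, 0.6, 0.8], (1, 0, 1)))  # 0.3 * (1 - 0.6) * 0.8 = 0.096
```

With three units, a general distribution over binary vectors has 2^3 - 1 = 7 free parameters, while the factorial one only needs 3.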
So now, I'm going to describe the wake-sleep algorithm that makes use of this idea of using the wrong distribution. In this algorithm, we have a neural net that has two different sets of weights. It's really a generative model, and so the weights shown in green, for generative, are the weights of the model. Those are the weights that define the probability distribution over data vectors. We've also got some extra weights, the weights shown in red, for recognition weights, and these are weights that are used for approximately getting the posterior distribution. That is, we're going to use these weights to get a factorial distribution at each hidden layer that approximates the posterior, but not very well.

So, in this algorithm there's a wake phase. In the wake phase, you put data in at the visible layer, the bottom, and you do a forward pass through the network using the recognition weights. In each hidden layer, you make a stochastic binary decision for each hidden unit independently, about whether it should be on or off. So, the forward pass gets us stochastic binary states for all of the hidden units. Then, once we have those stochastic binary states, we treat that as though it was a sample from the true posterior distribution given the data, and we do maximum likelihood learning. But what we're doing maximum likelihood learning for is not the recognition weights that we just used to get the approximate sample. It's the generative weights that define our model. So, you drive the system in the forward pass with the recognition weights, but you learn the generative weights.

In the sleep phase, you do the opposite. You drive the system with the generative weights. That is, you start with a random vector at the top hidden layer. You generate the binary states of those hidden units from their prior, in which they're independent. Then you go down through the system, generating the states of each layer, one layer at a time. Here you're using the generative model correctly; that's how the generative model says you should generate data, and so you can generate an unbiased sample from the model. Having used the generative weights to generate an unbiased sample, you then say, let's see if we can recover the hidden states from the data. That is, let's see if we can recover the hidden states at layer h2 from the hidden states at layer h1. So, you train the recognition weights to try and recover the hidden states that actually generated the states in the layer below. It's just the opposite of the wake phase: we're now using the generative weights to drive the system, and we're learning the recognition weights.
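Here is a rough sketch of one pass of each phase, for a net with a single hidden layer to keep things short (the video's net has several hidden layers; a deeper net applies the same pattern layer by layer). The layer sizes, learning rate, toy data, and variable names are assumptions for illustration, not part of the lecture; the updates are the standard delta rule for Sigmoid Belief Nets applied to whichever set of weights is being learned.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample(p):
    # Independent stochastic binary decision for each unit.
    return (rng.random(p.shape) < p).astype(float)

# Assumed sizes: 6 visible units, 4 hidden units.
n_vis, n_hid, lr = 6, 4, 0.1
W_gen = np.zeros((n_hid, n_vis))   # generative weights (hidden -> visible): the model itself
b_vis = np.zeros(n_vis)            # generative biases of the visible units
b_hid = np.zeros(n_hid)            # generative biases defining the prior over the hidden units
R     = np.zeros((n_vis, n_hid))   # recognition weights (visible -> hidden)
c_hid = np.zeros(n_hid)            # recognition biases

data = sample(np.full((20, n_vis), 0.3))   # toy binary training vectors, purely illustrative

for v in data:
    # --- wake phase: drive with the recognition weights, learn the generative weights ---
    h = sample(sigmoid(v @ R + c_hid))           # forward pass: stochastic binary hidden states
    p_v = sigmoid(h @ W_gen + b_vis)             # what the generative model predicts for the data
    W_gen += lr * np.outer(h, v - p_v)           # delta rule: make the model more likely to generate v
    b_vis += lr * (v - p_v)
    b_hid += lr * (h - sigmoid(b_hid))           # fit the generative prior to the inferred states

    # --- sleep phase: drive with the generative weights, learn the recognition weights ---
    h_dream = sample(sigmoid(b_hid))                     # sample hidden states from their prior
    v_dream = sample(sigmoid(h_dream @ W_gen + b_vis))   # generate a "dream" data vector
    q_h = sigmoid(v_dream @ R + c_hid)                   # try to recover the states that produced it
    R     += lr * np.outer(v_dream, h_dream - q_h)       # delta rule on the recognition weights
    c_hid += lr * (h_dream - q_h)
```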
It turns out that if you start with random weights and you alternate between wake phases and sleep phases, it learns a pretty good model. But there are flaws in this algorithm. The first flaw is a rather minor one, which is that the recognition weights are learning to invert the generative model, but at the beginning of learning they're learning to invert it in parts of the space where there isn't any data. That's because when you generate from the model, you're generating stuff that looks very different from the real data, because the weights aren't any good yet. That's a waste, but it's not a big problem.

The serious problem with this algorithm is that the recognition weights not only don't follow the gradient of the log probability of the data, they don't even follow the gradient of the variational bound on that probability. And because they're not following the right gradient, we get incorrect mode averaging, which I'll explain in the next slide.

A final problem is that we know the true posterior over the top hidden layer is bound to be far from independent, because of explaining away effects. And yet we're forced to approximate it with a distribution that assumes independence. This independence approximation might not be so bad for intermediate hidden layers, because, if we're lucky, the explaining away effects that come from below will be partially cancelled out by prior effects that come from above. You'll see that in much more detail later.

Despite all these problems, Karl Friston thinks this is how the brain works. When we initially came up with the algorithm, we thought it was an interesting new theory of the brain. I currently believe that it's got too many problems to be how the brain works, and that we'll find better algorithms.

So now let me explain mode averaging, using the little model with the earthquake and the truck that we saw before. Suppose we run the sleep phase and we generate data from this model. Most of the time, those top two units will be off, because they're very unlikely to turn on under their prior, and, because they're off, the visible unit will be firmly off, because its bias is -20. Just occasionally, with a probability of about e to the -10, one of the two top units will turn on, and it will be equally often the left one or the right one.
When that unit turns on, there's a probability of a half that the visible unit will turn on. So, if you think about the occasions on which the visible unit turns on, half of those occasions have the left-hand hidden unit on, the other half have the right-hand hidden unit on, and almost none of them have neither or both units on.

So now think about what the learning would do to the recognition weights. Half the time, when we have a 1 at the visible layer, the left-hand unit will be on at the top, so we'll actually learn to predict that it's on with a probability of 0.5, and the same for the right-hand unit. So the recognition weights will learn to produce a factorial distribution over the hidden layer of (0.5, 0.5). But that factorial distribution puts a quarter of its mass on the configuration (1, 1) and another quarter of its mass on the configuration (0, 0), and both of those are extremely unlikely configurations given that the visible unit was on. It would have been better just to pick one mode; that is, it would have been better just to go for truck, or just to go for earthquake. That's the best recognition model you can have if you're forced to use a factorial model.

So, even though the hidden configurations we're dealing with are best represented as the corners of a square, I'm going to show them as if they lie along a one-dimensional continuous axis. The true posterior is bimodal: it's focused on (1, 0) or (0, 1), and that's shown in black. The approximation you learn, if you use the sleep phase of the wake-sleep algorithm, is the red curve, which gives all four states of the hidden units equal probability. The best solution would be to pick one of those modes and give it all the probability mass. That's the best solution because, in variational learning, we're manipulating the true posterior to make it fit the approximation we're using. Normally in learning we manipulate an approximation to fit the true thing, but here it's backwards.
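To see the mode averaging numerically, here is a small sketch that enumerates the four hidden configurations of the truck-and-earthquake net. The video only restates the visible bias of -20 and the roughly e^-10 prior probability of each hidden cause; the weights of +20 and hidden biases of -10 come from the earlier example it refers back to, so treat those exact numbers as an assumption.

```python
import numpy as np
from itertools import product

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Assumed parameters from the earlier truck/earthquake example:
# two hidden causes with bias -10, each connected to the visible
# "house jumps" unit with weight +20; visible bias -20.
b_hid = np.array([-10.0, -10.0])
w = np.array([20.0, 20.0])
b_vis = -20.0

# Joint probability p(h1, h2, v=1) for each of the four hidden configurations.
joint = {}
for h in product([0, 1], repeat=2):
    h = np.array(h, dtype=float)
    p_h = np.prod(sigmoid(b_hid) ** h * (1 - sigmoid(b_hid)) ** (1 - h))  # factorial prior over causes
    p_v = sigmoid(w @ h + b_vis)                                          # p(v = 1 | h)
    joint[tuple(int(x) for x in h)] = p_h * p_v

z = sum(joint.values())
posterior = {h: p / z for h, p in joint.items()}
print(posterior)   # essentially all of the mass sits on (1, 0) and (0, 1): the true posterior is bimodal

# The factorial approximation learned in the sleep phase gives each hidden unit
# probability 0.5, so it puts a quarter of its mass on (0, 0) and another quarter
# on (1, 1), configurations that are essentially impossible under the true posterior.
```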