1 00:00:00,000 --> 00:00:06,525 In this video, I'll go into more detail about how we can speed up the Boltzmann 2 00:00:06,525 --> 00:00:11,548 machine learning algorithm by using cleverer ways of keeping Markov chains 3 00:00:11,548 --> 00:00:16,355 near the equilibrium distribution, or by using what are called mean field methods. 4 00:00:16,355 --> 00:00:20,686 The material is quite advanced and so it's not really part of the course. 5 00:00:20,686 --> 00:00:24,484 There won't be any quizzes on it and it's not on the final test. 6 00:00:24,484 --> 00:00:28,697 You can safely skip this video. It's included for people who are really 7 00:00:28,697 --> 00:00:32,377 interested in how to get deep Boltzmann machines to work well. 8 00:00:32,377 --> 00:00:37,005 There are better ways of collecting the statistics than the method that Terry 9 00:00:37,005 --> 00:00:42,477 Sejnowski and I originally came up with. If we start from a random state, it may 10 00:00:42,477 --> 00:00:45,400 take a long time to reach thermal equilibrium. 11 00:00:46,160 --> 00:00:51,215 Also, there's no easy test for whether you've reached thermal equilibrium, so we 12 00:00:51,215 --> 00:00:57,755 don't know how long we need to run for. So, the idea is why not start from 13 00:00:57,755 --> 00:01:02,545 whatever state you ended up in last time you saw that particular data vector? 14 00:01:02,545 --> 00:01:07,522 So, we remember the interpretation of the data vector in the hidden units, and we 15 00:01:07,522 --> 00:01:12,747 start from there. This stored state, the interpretation of 16 00:01:12,747 --> 00:01:21,308 the data vector, is called a particle. Using particles that persist gives us a 17 00:01:21,308 --> 00:01:27,797 warm start and it has a big advantage. If we were at equilibrium before and we 18 00:01:27,797 --> 00:01:33,258 only updated the weights a little bit, it'll only take a few updates of the units 19 00:01:33,258 --> 00:01:36,360 in a particle to bring it back to equilibrium.
20 00:01:37,200 --> 00:01:42,403 We can use particles for both the positive phase, where we have a clamped data vector, 21 00:01:42,403 --> 00:01:45,640 and for the negative phase, when nothing is clamped. 22 00:01:47,880 --> 00:01:53,509 So, here's the method for collecting statistics introduced by Radford Neal in 23 00:01:53,509 --> 00:01:57,273 1992. In the positive phase, you have a set of 24 00:01:57,273 --> 00:02:02,752 data-specific particles, one or a few per training case. And each particle has a 25 00:02:02,752 --> 00:02:08,646 current value that's the configuration of the hidden units plus which data vector it 26 00:02:08,646 --> 00:02:14,139 goes with. You sequentially update all the hidden 27 00:02:14,139 --> 00:02:18,780 units a few times in each particle with the relevant data vector clamped. 28 00:02:19,620 --> 00:02:25,873 And then, for every connected pair of units, you average the probability of the 29 00:02:25,873 --> 00:02:29,240 two units both being on over all these particles. 30 00:02:31,180 --> 00:02:35,512 In the negative phase, you keep a set of fantasy particles. 31 00:02:35,512 --> 00:02:41,565 That is, these are global configurations. And again, after each weight update, you 32 00:02:41,565 --> 00:02:46,671 sequentially update all the units in each fantasy particle a few times. 33 00:02:46,671 --> 00:02:49,980 Now, you're updating the visible units as well. 34 00:02:50,940 --> 00:02:56,648 And for every connected pair of units, you average SiSj over all the fantasy 35 00:02:56,648 --> 00:03:01,742 particles. The learning rule is then that the change in 36 00:03:01,742 --> 00:03:07,079 the weights is proportional to the difference between the average you got with data, averaged over all 37 00:03:07,079 --> 00:03:12,349 training data, and the average you got with the fantasy particles when nothing 38 00:03:12,349 --> 00:03:17,067 was clamped.
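The particle-based procedure just described can be sketched in a few lines. This is only an illustration under stated assumptions: the function names (`gibbs_sweep`, `pair_average`, `weight_update`) are hypothetical, the network is treated as fully connected with a symmetric weight matrix, and each particle is stored as a plain list of binary states over all units, visible units first.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gibbs_sweep(state, weights, biases, free_units, rng):
    """Sequentially resample each unit in free_units given all the others.

    For a data-specific particle, free_units holds only the hidden units
    (the data vector stays clamped). For a fantasy particle it holds every unit.
    """
    for i in free_units:
        total_input = biases[i] + sum(
            weights[i][j] * state[j] for j in range(len(state)) if j != i
        )
        state[i] = 1 if rng.random() < sigmoid(total_input) else 0


def pair_average(particles, i, j):
    """Average of s_i * s_j over a set of particles."""
    return sum(p[i] * p[j] for p in particles) / len(particles)


def weight_update(data_particles, fantasy_particles, n_units, learning_rate):
    """Delta w_ij = learning_rate * (<s_i s_j>_data - <s_i s_j>_fantasy)."""
    delta = [[0.0] * n_units for _ in range(n_units)]
    for i in range(n_units):
        for j in range(i + 1, n_units):
            difference = pair_average(data_particles, i, j) - pair_average(
                fantasy_particles, i, j
            )
            delta[i][j] = delta[j][i] = learning_rate * difference
    return delta
```

In a training loop under these assumptions, each weight update would be followed by a few calls to `gibbs_sweep` on every stored particle, so the particles only ever have to catch up with a small change in the weights.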
This works better than the learning rule 39 00:03:17,067 --> 00:03:21,780 that Terry Sejnowski and I introduced, at least for full batch learning. 40 00:03:24,380 --> 00:03:28,834 However, it's difficult to apply this approach to mini-batches. 41 00:03:28,834 --> 00:03:34,582 And the reason is that, by the time we get back to the same data vector, if we're 42 00:03:34,582 --> 00:03:39,828 using mini-batch learning, the weights will have been updated many times. 43 00:03:39,828 --> 00:03:45,648 So, the stored data-specific particle for that data vector won't be anywhere near 44 00:03:45,648 --> 00:03:50,696 thermal equilibrium anymore. The hidden units won't be in thermal 45 00:03:50,696 --> 00:03:55,820 equilibrium with the visible units of the particle given the new weights. 46 00:03:57,780 --> 00:04:02,250 And again, we don't know how long we're going to have to run for before we get 47 00:04:02,250 --> 00:04:08,229 close to equilibrium again. So, we can overcome this by making a 48 00:04:08,229 --> 00:04:11,680 strong assumption about how we understand the world. 49 00:04:11,680 --> 00:04:14,600 It's a kind of epistemological assumption. 50 00:04:16,240 --> 00:04:21,222 We're going to assume that when a data vector is clamped, the set of good 51 00:04:21,222 --> 00:04:27,200 explanations, that is, states of the hidden units that act as interpretations of that 52 00:04:27,200 --> 00:04:32,457 data vector, is unimodal. That means we're saying that, for a given 53 00:04:32,457 --> 00:04:37,266 data vector, there aren't two very different explanations for that data 54 00:04:37,266 --> 00:04:40,585 vector. We assume that for sensory input, there is 55 00:04:40,585 --> 00:04:44,920 one correct explanation. And if we have a good model of the data, 56 00:04:44,920 --> 00:04:49,120 our model will give us one energy minimum for that data point. 57 00:04:50,420 --> 00:04:55,420 This is a restriction on the kinds of models we're willing to learn.
58 00:04:55,420 --> 00:05:01,008 We're going to use a learning algorithm that's incapable of learning models in 59 00:05:01,008 --> 00:05:05,420 which a data vector has many very different interpretations. 60 00:05:05,900 --> 00:05:10,957 Provided we're willing to make this assumption, we can use a very efficient 61 00:05:10,957 --> 00:05:15,946 method for approaching thermal equilibrium, or an approximation to thermal 62 00:05:15,946 --> 00:05:21,300 equilibrium, with the data. It's called a mean field approximation. 63 00:05:23,260 --> 00:05:27,854 So, if we want to get the statistics right, we need to update the units 64 00:05:27,854 --> 00:05:32,762 stochastically and sequentially. And the update rule is that the probability of 65 00:05:32,762 --> 00:05:37,311 turning on unit i is the logistic function of the total input it receives 66 00:05:37,311 --> 00:05:44,310 from the other units and its bias, where Sj, the state of another unit, is a 67 00:05:44,310 --> 00:05:52,262 stochastic binary thing. Now, instead of using that rule, we could 68 00:05:52,262 --> 00:05:58,708 say, we're not going to keep binary states for unit i, we're going to keep a real 69 00:05:58,708 --> 00:06:03,140 value between zero and one which we call a probability. 70 00:06:03,600 --> 00:06:09,971 And that probability at time t + one is going to be the output of the logistic 71 00:06:09,971 --> 00:06:13,923 function. The total input is the bias, plus the 72 00:06:13,923 --> 00:06:18,520 sum of all the probabilities at time t times the weights. 73 00:06:18,960 --> 00:06:26,300 So, we're replacing this stochastic binary thing by a real-valued probability. 74 00:06:27,060 --> 00:06:33,298 And that's not quite right because this stochastic binary thing is inside a 75 00:06:33,298 --> 00:06:38,132 non-linear function. If it was a linear function, things would 76 00:06:38,132 --> 00:06:41,120 be fine.
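The two update rules described here can be put side by side in a small sketch. The function names and the list-of-lists weight layout are assumptions made for illustration; `weights[i][j]` is taken to be the symmetric weight between units i and j.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def stochastic_update(i, states, weights, biases, rng):
    """Sample a binary state for unit i.

    p(s_i = 1) = sigmoid(b_i + sum_j s_j * w_ij), with binary s_j.
    """
    total_input = biases[i] + sum(
        states[j] * weights[i][j] for j in range(len(states)) if j != i
    )
    return 1 if rng.random() < sigmoid(total_input) else 0


def mean_field_update(i, probs, weights, biases):
    """Deterministic mean field version.

    p_i(t + 1) = sigmoid(b_i + sum_j p_j(t) * w_ij), with real-valued p_j.
    """
    total_input = biases[i] + sum(
        probs[j] * weights[i][j] for j in range(len(probs)) if j != i
    )
    return sigmoid(total_input)
```

A quick worked number shows why the replacement is only approximate. Suppose a unit has zero bias and a single neighbour connected with weight 4 that is on half the time. Averaging the logistic over the two binary cases gives 0.5 * sigmoid(0) + 0.5 * sigmoid(4), about 0.74, whereas the mean field rule computes sigmoid(0.5 * 4), about 0.88.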
But because the logistic is non-linear, we 77 00:06:41,120 --> 00:06:45,852 don't get the right answer when we put probabilities instead of fluctuating 78 00:06:45,852 --> 00:06:51,020 binary things inside. However, it works pretty well. 79 00:06:52,000 --> 00:06:57,211 It can go wrong by giving us biphasic oscillations because now we're going to be 80 00:06:57,211 --> 00:07:01,955 updating everything in parallel. And we can normally deal with those by 81 00:07:01,955 --> 00:07:06,499 using what's called damped mean field, where we compute that p_i at time 82 00:07:06,499 --> 00:07:08,771 t + one. But we don't go all the way there. 83 00:07:08,771 --> 00:07:14,250 We go to a point in between where we are now and where this update wants us to go. 84 00:07:14,250 --> 00:07:20,487 So, in damped mean field, we'll go to lambda times the place we are now, plus 85 00:07:20,487 --> 00:07:26,058 one minus lambda times the place the update rule tells us to go to. 86 00:07:26,058 --> 00:07:34,221 And that will kill oscillations. Now, we can get an efficient mini-batch 87 00:07:34,221 --> 00:07:40,068 learning procedure for Boltzmann machines, and this is what Russ Salakhutdinov 88 00:07:40,068 --> 00:07:45,149 realized. In the positive phase, we can initialize 89 00:07:45,149 --> 00:07:50,256 all probabilities at 0.5. We can clamp a data vector on the visible 90 00:07:50,256 --> 00:07:56,271 units, and we can update all the hidden units in parallel using mean field until 91 00:07:56,271 --> 00:07:59,880 convergence. And for mean field, you can recognize 92 00:07:59,880 --> 00:08:03,640 convergence as when the probabilities stop changing. 93 00:08:06,060 --> 00:08:12,680 And once we've converged, we record PiPj for every connected pair of units. 94 00:08:14,240 --> 00:08:17,701 In the negative phase, we do what we were doing before.
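The damped mean field positive phase described above can be sketched as follows. This is a minimal illustration, not the procedure's actual implementation: the names `damped_mean_field` and `positive_statistics`, the tolerance, the iteration cap, and the choice to treat a non-zero weight as "connected" are all assumptions. The clamped data vector is simply left out of `free_units`, and `damping` plays the role of lambda.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def damped_mean_field(probs, weights, biases, free_units,
                      damping=0.5, tolerance=1e-6, max_iters=1000):
    """Parallel damped mean field updates on the units in free_units.

    p_i <- damping * p_i + (1 - damping) * sigmoid(b_i + sum_j p_j * w_ij)
    Returns (probabilities, converged_flag).
    """
    probs = list(probs)
    n = len(probs)
    for _ in range(max_iters):
        # All targets are computed from the old probabilities (parallel update).
        targets = {
            i: sigmoid(biases[i] + sum(
                probs[j] * weights[i][j] for j in range(n) if j != i))
            for i in free_units
        }
        largest_change = 0.0
        for i in free_units:
            new_p = damping * probs[i] + (1.0 - damping) * targets[i]
            largest_change = max(largest_change, abs(new_p - probs[i]))
            probs[i] = new_p
        # Convergence is recognized when the probabilities stop changing.
        if largest_change < tolerance:
            return probs, True
    return probs, False


def positive_statistics(probs, weights):
    """Record p_i * p_j for every connected (non-zero weight) pair of units."""
    n = len(probs)
    return {
        (i, j): probs[i] * probs[j]
        for i in range(n) for j in range(i + 1, n)
        if weights[i][j] != 0
    }
```

As a small usage example, take one clamped visible unit (index 0, value 1.0) connected with weight 1 to three hidden units that inhibit each other with weight -3 and have bias 3, all starting at 0.5. With `damping=0.0` the parallel updates settle into a two-phase oscillation and never converge, while `damping=0.5` converges to a value near 0.6 for each hidden unit.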
95 00:08:17,701 --> 00:08:22,508 We keep a set of fantasy particles, each of which has a value that's a global 96 00:08:22,508 --> 00:08:25,457 configuration. And after each weight update, we 97 00:08:25,457 --> 00:08:29,880 sequentially update all the units in each fantasy particle a few times. 98 00:08:31,140 --> 00:08:36,441 And then, for every connected pair of units, we average SiSj, these stochastic 99 00:08:36,441 --> 00:08:41,952 binary things, over all fantasy particles. And the difference in those averages is 100 00:08:41,952 --> 00:08:45,719 the learning rule. That is, we change the weights by an 101 00:08:45,719 --> 00:08:53,150 amount proportional to that difference. If we want to make the updates for the 102 00:08:53,150 --> 00:08:58,013 fantasy particles more parallel, we can change the architecture of the Boltzmann 103 00:08:58,013 --> 00:09:02,194 machine. So, we're going to have a special 104 00:09:02,194 --> 00:09:07,221 architecture that allows alternating parallel updates for the fantasy 105 00:09:07,221 --> 00:09:11,873 particles. We have no connections within a layer, and 106 00:09:11,873 --> 00:09:17,600 we have no skip-layer connections, but we allow ourselves lots of hidden layers. 107 00:09:18,240 --> 00:09:23,720 So, the architecture looks like this. We call it a Deep Boltzmann Machine. 108 00:09:24,640 --> 00:09:29,279 And, it's really a general Boltzmann machine with lots of missing connections. 109 00:09:29,279 --> 00:09:32,806 If all those skip-layer connections were present, 110 00:09:32,806 --> 00:09:37,816 we wouldn't really have layers at all, it would just be a general Boltzmann machine. 111 00:09:37,816 --> 00:09:41,900 But, in this special architecture, there's something nice we can do. 112 00:09:43,400 --> 00:09:49,004 We can update the states of, for example, the first hidden layer and the third hidden 113 00:09:49,004 --> 00:09:54,280 layer, given the current states of the visible units and the second hidden layer.
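A sketch of this kind of alternating update on one fantasy particle is given below. The data layout is an assumption for illustration: `layers[0]` is the visible layer, `layers[1]` the first hidden layer, and so on, and `weights[k][a][b]` is the hypothetical weight between unit a of layer k and unit b of layer k + 1. Because there are no connections within a layer and no skip-layer connections, every unit in a layer depends only on the two adjacent layers, so a whole layer can be resampled at once.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sample_layer(k, layers, weights, biases, rng):
    """Resample every unit of layer k given the adjacent layers k - 1 and k + 1."""
    new_states = []
    for a in range(len(layers[k])):
        total_input = biases[k][a]
        if k > 0:  # input from the layer below
            total_input += sum(
                layers[k - 1][b] * weights[k - 1][b][a]
                for b in range(len(layers[k - 1]))
            )
        if k < len(layers) - 1:  # input from the layer above
            total_input += sum(
                layers[k + 1][b] * weights[k][a][b]
                for b in range(len(layers[k + 1]))
            )
        new_states.append(1 if rng.random() < sigmoid(total_input) else 0)
    layers[k] = new_states


def alternating_sweep(layers, weights, biases, rng):
    """Update odd layers given even layers, then even layers given odd layers."""
    for k in range(1, len(layers), 2):  # first hidden, third hidden, ...
        sample_layer(k, layers, weights, biases, rng)
    for k in range(0, len(layers), 2):  # visible, second hidden, ...
        sample_layer(k, layers, weights, biases, rng)
```

The loops over units and layers are written sequentially here for clarity. Within each half of the sweep no unit reads a state that the same half writes, which is what would allow those updates to run in parallel.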
114 00:09:54,580 --> 00:09:59,147 And then, we can update the states of the visible units and the second hidden layer. 115 00:09:59,147 --> 00:10:02,043 And then, we can go back and update the other states, 116 00:10:02,043 --> 00:10:04,661 and we can go backwards and forwards like this. 117 00:10:04,661 --> 00:10:09,284 And so, we can update the states of half of all the units in parallel and that'll be a 118 00:10:09,284 --> 00:10:15,958 correct update. So, one question is, if we have a deep 119 00:10:15,958 --> 00:10:21,215 Boltzmann machine like that, trained by using mean field for the positive phase 120 00:10:21,215 --> 00:10:26,333 and updating fantasy particles by alternating between even layers and odd 121 00:10:26,333 --> 00:10:31,866 layers for the negative phase, can we learn good models of things like the MNIST 122 00:10:31,866 --> 00:10:34,840 digits, or indeed, more complicated things? 123 00:10:35,180 --> 00:10:40,768 So, one way to tell whether you've learned a good model is, after learning, you remove 124 00:10:40,768 --> 00:10:44,760 all the input and you just generate samples from your model. 125 00:10:45,120 --> 00:10:49,608 So, you run the Markov chain for a long time until it's burned in, and then you 126 00:10:49,608 --> 00:10:54,091 look at the samples you get. So, Russ Salakhutdinov used a deep 127 00:10:54,091 --> 00:11:00,697 Boltzmann machine to model MNIST digits using mean field for the positive phase, 128 00:11:00,697 --> 00:11:06,688 and alternating updates of the layers of the particles for the negative phase. 129 00:11:06,688 --> 00:11:13,517 And the real data looks like this. And the data that he got from his model 130 00:11:13,517 --> 00:11:18,643 looks like this. You can see, they're actually fairly 131 00:11:18,643 --> 00:11:21,288 similar. The model is producing things very like 132 00:11:21,288 --> 00:11:24,540 the MNIST digits so it's learned a pretty good model. 133 00:11:25,260 --> 00:11:31,527 So here's a puzzle.
When he was learning that, he was using 134 00:11:31,527 --> 00:11:37,615 mini-batches with 100 data examples and also he was using 100 fantasy particles, 135 00:11:37,615 --> 00:11:41,572 the same 100 fantasy particles for every mini-batch. 136 00:11:41,572 --> 00:11:48,040 And the puzzle is, how can we estimate the negative statistics with only 100 negative 137 00:11:48,040 --> 00:11:55,077 examples to characterize the whole space? For all interesting problems, the global 138 00:11:55,077 --> 00:11:58,300 configuration space is going to be highly multimodal. 139 00:11:59,520 --> 00:12:05,380 And how do we manage to find and represent all the modes with only 100 particles? 140 00:12:06,400 --> 00:12:11,198 There's an interesting answer to this. The learning interacts with the Markov 141 00:12:11,198 --> 00:12:15,935 chain that's being used to gather the negative statistics, the one that's 142 00:12:15,935 --> 00:12:20,921 used to update the fantasy particles, and it interacts with it to make it have a 143 00:12:20,921 --> 00:12:27,512 much higher effective mixing rate. That means we cannot analyze the learning 144 00:12:27,512 --> 00:12:31,970 by thinking of it as an outer loop that updates the weights, 145 00:12:31,970 --> 00:12:36,924 and an inner loop that gathers statistics with a fixed set of weights. 146 00:12:36,924 --> 00:12:41,100 The learning is affecting how effective that inner loop is. 147 00:12:42,960 --> 00:12:48,261 The reason for this is that whenever the fantasy particles outnumber the positive 148 00:12:48,261 --> 00:12:53,368 data, the energy surface is raised, and this has an effect on the mixing rate of 149 00:12:53,368 --> 00:12:56,730 the Markov chain. It makes the fantasies rush around 150 00:12:56,730 --> 00:13:00,660 hyperactively, and they move around much faster than the 151 00:13:00,660 --> 00:13:04,820 mixing rate of the Markov chain defined by the current static weights.
152 00:13:06,160 --> 00:13:09,300 So, here's a picture that shows you what's going on. 153 00:13:10,900 --> 00:13:16,560 If there's a mode in the energy surface that has more fantasy particles than data, 154 00:13:16,560 --> 00:13:22,014 the energy surface will be raised until the fantasy particles escape from that 155 00:13:22,014 --> 00:13:25,384 mode. So, the mode on the left has four fantasy 156 00:13:25,384 --> 00:13:30,313 particles and only two data points. So, the effect of the learning is going to be 157 00:13:30,313 --> 00:13:34,665 to raise the energy there. And that energy barrier might be much too 158 00:13:34,665 --> 00:13:39,530 high for a Markov chain to be able to cross, so the mixing rate will be very 159 00:13:39,530 --> 00:13:42,219 slow. But, the learning will actually spill 160 00:13:42,219 --> 00:13:46,700 those red particles out of that energy minimum by raising the minimum. 161 00:13:47,020 --> 00:13:51,148 The minimum will get filled up and the fantasy particles will escape and go off somewhere 162 00:13:51,148 --> 00:13:56,654 else, to some other deep minimum. So, we can get out of minima that the 163 00:13:56,654 --> 00:14:01,960 Markov chain would not be able to get out of, at least, not in a reasonable time. 164 00:14:03,220 --> 00:14:08,545 So, what's going on here is the energy surface is really being used for two 165 00:14:08,545 --> 00:14:12,734 different purposes. The energy surface represents our model, 166 00:14:12,734 --> 00:14:18,201 but it's also being manipulated by the learning algorithm to make the Markov 167 00:14:18,201 --> 00:14:21,822 chain mix faster. Or rather, to have the effect of a 168 00:14:21,822 --> 00:14:27,113 faster-mixing Markov chain. Once the fantasy particles have filled up 169 00:14:27,113 --> 00:14:31,640 one hole, they'll rush off to somewhere else and deal with the next problem.
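The way learning fills up a minimum that holds too many fantasy particles can be illustrated with a toy example. This is a deliberate simplification and an assumption of the sketch: instead of weights, it keeps one energy value per named state in a table, and the hypothetical `energy_update` raises the energy where fantasy particles sit and lowers it where data sits, which is the direction in which the learning rule moves the energy surface.

```python
from collections import Counter


def energy_update(energy, data_states, fantasy_states, learning_rate):
    """One learning step on a tabular energy function.

    Each state's energy changes by
    learning_rate * (fraction of fantasy particles there - fraction of data there).
    """
    data_counts = Counter(data_states)
    fantasy_counts = Counter(fantasy_states)
    new_energy = {}
    for state, value in energy.items():
        fantasy_fraction = fantasy_counts[state] / len(fantasy_states)
        data_fraction = data_counts[state] / len(data_states)
        new_energy[state] = value + learning_rate * (fantasy_fraction - data_fraction)
    return new_energy


# Two minima separated by a barrier. The left minimum holds four fantasy
# particles but only two of the four data points.
energy = {"left": -2.0, "barrier": 1.0, "right": -2.0}
data = ["left", "left", "right", "right"]
fantasies = ["left", "left", "left", "left"]

for _ in range(7):
    energy = energy_update(energy, data, fantasies, learning_rate=1.0)
```

With these made-up numbers the left minimum rises by 0.5 per step and the barrier stays where it is, so after seven steps the left state is higher than the barrier and nothing holds the fantasy particles there any more. The particles themselves are kept fixed in this sketch; in the real procedure they would move as soon as the surface allowed it.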
170 00:14:32,100 --> 00:14:36,863 An analogy for them is that they're like investigative journalists who rush in to 171 00:14:36,863 --> 00:14:41,150 investigate some nasty problem. As soon as the publicity has caused that 172 00:14:41,150 --> 00:14:45,080 problem to be fixed, instead of saying, okay, everything is okay now, 173 00:14:45,080 --> 00:14:47,760 they rush off to find the next nasty problem.