1 00:00:01,060 --> 00:00:07,263 In this video we'll look in more detail at what happens when a neural network is 2 00:00:07,263 --> 00:00:13,084 discriminatively fine-tuned after it's first been pre-trained as a stack of 3 00:00:13,084 --> 00:00:17,830 Boltzmann machines. What we'll see is that the weights in the 4 00:00:17,830 --> 00:00:22,490 lower layers hardly change at all. But nevertheless these tiny changes 5 00:00:22,490 --> 00:00:26,840 make a big difference in the classification performance of the neural 6 00:00:26,840 --> 00:00:30,880 net, because they put the decision boundaries in the right places. 7 00:00:32,120 --> 00:00:37,800 We'll also see that the effect of pre-training is to make deeper networks 8 00:00:37,800 --> 00:00:43,887 more effective than shallower ones. Without pre-training it's often the other 9 00:00:43,887 --> 00:00:49,021 way around. Finally, I give a fairly general argument 10 00:00:49,021 --> 00:00:54,033 about why it makes sense to start by doing generative training, 11 00:00:54,033 --> 00:00:59,920 and only after this is well under way to consider discriminative training. 12 00:01:01,280 --> 00:01:07,107 So now we're going to look at some work done in Yoshua Bengio's lab, examining 13 00:01:07,107 --> 00:01:13,460 what happens during fine-tuning after a net's been generatively pre-trained. 14 00:01:13,460 --> 00:01:18,365 If you look on the left, there are the receptive fields in the first hidden 15 00:01:18,365 --> 00:01:23,473 layer of feature detectors, after the generative pre-training but before the 16 00:01:23,473 --> 00:01:27,034 fine-tuning. Then on the right, there are the same 17 00:01:27,034 --> 00:01:32,095 receptive fields after the fine-tuning, and you'll see almost nothing has 18 00:01:32,095 --> 00:01:35,173 changed. Nevertheless, the changes helped with 19 00:01:35,173 --> 00:01:40,005 discrimination. 
Here's an example of how pre-training 20 00:01:40,005 --> 00:01:45,220 reduces the test errors for networks with one hidden layer. 21 00:01:45,220 --> 00:01:51,149 The task was discriminating between digits in a very large set of distorted 22 00:01:51,149 --> 00:01:54,505 digits. And you can see that after the back 23 00:01:54,505 --> 00:02:00,200 propagation fine-tuning, the networks with pre-training almost always did 24 00:02:00,200 --> 00:02:03,790 better than the networks without pre-training. 25 00:02:03,790 --> 00:02:07,791 The effect gets even bigger if you use deeper networks. 26 00:02:07,791 --> 00:02:12,884 So here you can see that there's basically no overlap between the two 27 00:02:12,884 --> 00:02:16,812 distributions. And the deep networks with pre-training 28 00:02:16,812 --> 00:02:22,051 have got better than the shallow networks, and the deep networks without 29 00:02:22,051 --> 00:02:26,460 pre-training have got worse than the shallow networks. 30 00:02:26,460 --> 00:02:32,588 This is showing you the classification error and the variation in classification 31 00:02:32,588 --> 00:02:39,520 error as you change the number of layers when you're not doing pre-training. 32 00:02:39,520 --> 00:02:43,085 And you can see that two layers appears to be best. 33 00:02:43,085 --> 00:02:48,049 And by the time you've got four layers, you're doing considerably worse. 34 00:02:48,049 --> 00:02:53,222 By contrast, if you use pre-training, four layers is better than two layers. 35 00:02:53,222 --> 00:02:55,735 There's much less variation and the error is lower. 36 00:02:56,983 --> 00:03:02,464 This is a visualization made with t-SNE of what happens to the weights during 37 00:03:02,464 --> 00:03:07,598 training for both pre-trained and non-pre-trained networks. They are all 38 00:03:07,598 --> 00:03:12,871 plotted in the same space but you can see they form two distinct classes of 39 00:03:12,871 --> 00:03:16,270 networks. 
The ones at the top are networks without 40 00:03:16,270 --> 00:03:21,196 pre-training and the ones at the bottom are networks with pre-training. 41 00:03:21,196 --> 00:03:24,180 Each point shows a model in function space. 42 00:03:24,180 --> 00:03:29,083 It's no use comparing weight vectors because two nets might differ by having 43 00:03:29,083 --> 00:03:34,050 two of the hidden units swapped round. So they behave exactly the same way, but 44 00:03:34,050 --> 00:03:39,329 the weights would look very different. In order to compare them, you have to 45 00:03:39,329 --> 00:03:42,730 compare the functions they implement. 46 00:03:42,730 --> 00:03:48,678 And a way to do that is to have a suite of test cases and look at the outputs the 47 00:03:48,678 --> 00:03:54,263 networks produce on those test cases, and then concatenate those outputs into 48 00:03:54,263 --> 00:03:58,471 one great long vector. And so if two networks produce very 49 00:03:58,471 --> 00:04:04,056 similar outputs for all the test cases, that concatenated vector will be very 50 00:04:04,056 --> 00:04:08,989 similar for the two networks. Now you take those concatenated output 51 00:04:08,989 --> 00:04:13,524 vectors and you plot those in 2D using t-SNE. 52 00:04:13,524 --> 00:04:21,458 The colors show the stages of training, so if you look at the networks at the top, 53 00:04:21,458 --> 00:04:27,362 there's an initial blob in dark blue. And then you can see that it's all moving 54 00:04:27,362 --> 00:04:32,550 in roughly the same direction. In other words, the networks after one 55 00:04:32,550 --> 00:04:38,077 epoch of learning are all more similar to one another than they are to the initial 56 00:04:38,077 --> 00:04:41,140 networks. That's even more pronounced with the 57 00:04:41,140 --> 00:04:46,268 pre-trained networks at the bottom. So the color tells you which epoch you're 58 00:04:46,268 --> 00:04:48,865 in. 
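The point about comparing networks in function space rather than weight space can be illustrated with a minimal sketch. Everything here is an assumption made for illustration and is not taken from the work described in the lecture: the tiny one-hidden-layer net, the tanh units, the random test cases, and the helper names forward, function_vector and swap_hidden_units are all hypothetical. The sketch swaps two hidden units, then compares the two nets by their weights and by their concatenated outputs on a suite of test cases.

```python
import math
import random


def forward(net, x):
    """One-hidden-layer net (illustrative): tanh hidden units, linear outputs."""
    W1, W2 = net
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in W2]


def function_vector(net, test_cases):
    """Concatenate the outputs on every test case into one long vector."""
    vec = []
    for x in test_cases:
        vec.extend(forward(net, x))
    return vec


def swap_hidden_units(net, i, j):
    """Return a copy of the net with hidden units i and j exchanged."""
    W1, W2 = net
    W1 = [row[:] for row in W1]
    W2 = [row[:] for row in W2]
    W1[i], W1[j] = W1[j], W1[i]          # swap incoming weights (rows of W1)
    for row in W2:
        row[i], row[j] = row[j], row[i]  # swap outgoing weights (columns of W2)
    return W1, W2


def flatten(net):
    """All weights of the net as one flat list."""
    return [w for matrix in net for row in matrix for w in row]


def distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))


rng = random.Random(0)
n_in, n_hidden, n_out = 4, 3, 2
W1 = [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_hidden)]
W2 = [[rng.gauss(0, 1) for _ in range(n_hidden)] for _ in range(n_out)]
net_a = (W1, W2)
net_b = swap_hidden_units(net_a, 0, 1)

# A small suite of random test inputs (a stand-in for real test cases).
test_cases = [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(20)]

weight_gap = distance(flatten(net_a), flatten(net_b))
function_gap = distance(function_vector(net_a, test_cases),
                        function_vector(net_b, test_cases))

print("distance in weight space:  ", weight_gap)
print("distance in function space:", function_gap)
```

The two nets are far apart as weight vectors but essentially identical as function vectors, which is why the plot uses the concatenated outputs. In a visualization like the one described, one such vector per network per epoch would then be handed to a 2D embedding method such as t-SNE; that step is left out of this sketch.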
The trajectories at the top without 59 00:04:48,865 --> 00:04:54,259 pre-training show that different networks end up in different places in function 60 00:04:54,259 --> 00:04:56,790 space. And they're quite widely spread. 61 00:04:56,790 --> 00:05:02,087 The trajectories at the bottom show that with pre-training, you end up in a quite 62 00:05:02,087 --> 00:05:07,058 different region of function space. And the networks tend to be more similar 63 00:05:07,058 --> 00:05:11,791 to one another. But the main point is there's no overlap. 64 00:05:11,791 --> 00:05:17,247 The kinds of solutions you find, if you pre-train the networks generatively, are 65 00:05:17,247 --> 00:05:22,703 just qualitatively different from the kinds of solutions you find if you start 66 00:05:22,703 --> 00:05:26,915 with small random weights. The last thing I want to do in this 67 00:05:26,915 --> 00:05:30,580 video is to explain why pre-training makes sense. 68 00:05:30,580 --> 00:05:35,858 So let's imagine that the way we generated pairs of an image and a label 69 00:05:35,858 --> 00:05:41,209 was by taking the stuff in the real world, using that to generate an image, 70 00:05:41,209 --> 00:05:45,320 for example by taking a photograph of something. 71 00:05:45,320 --> 00:05:50,868 And then having generated the image we attached a label to it that didn't depend 72 00:05:50,868 --> 00:05:55,251 on the stuff in the world. So contingent on the image itself, the 73 00:05:55,251 --> 00:06:01,080 stuff in the world is irrelevant. The label just depends on the pixels in the image. 74 00:06:01,080 --> 00:06:07,130 That would be the case for example if the label told us whether the top left pixel 75 00:06:07,130 --> 00:06:14,428 was similar to the bottom right pixel. Now if we generated images that way, then 76 00:06:14,428 --> 00:06:20,120 it would make sense to try and learn a mapping from images to labels. 
77 00:06:20,120 --> 00:06:26,848 Because the labels depend directly on the images. But actually, it's more plausible 78 00:06:26,848 --> 00:06:31,775 that the way we generate image-label pairs is by there being stuff in the 79 00:06:31,775 --> 00:06:36,769 world that gives rise to the image. And the reason the image has the name it 80 00:06:36,769 --> 00:06:42,025 has is because of the stuff in the world, not because of the pixels in the image. 81 00:06:42,025 --> 00:06:44,587 So you see a cow. You take a photograph. 82 00:06:44,587 --> 00:06:49,778 And you call that a photograph of a cow, because you were looking at a cow when 83 00:06:49,778 --> 00:06:52,669 you took it. Now the point is, there's a high 84 00:06:52,669 --> 00:06:56,480 bandwidth from the stuff in the world to the image. 85 00:06:56,480 --> 00:07:01,160 And there's a low bandwidth from the stuff in the world to the label. 86 00:07:01,160 --> 00:07:06,040 For example if I just say cow, you don't know whether the cow is upside down, 87 00:07:06,040 --> 00:07:11,178 whether it was brown or black and white, whether it was alive or dead, how big it 88 00:07:11,178 --> 00:07:16,123 was, what else was in the image, whether it was facing you or facing away from 89 00:07:16,123 --> 00:07:19,668 you. None of those things are conveyed by 90 00:07:19,668 --> 00:07:23,466 the label. If you see an image with thousands and 91 00:07:23,466 --> 00:07:27,651 thousands of pixels, you typically know all of those things. 92 00:07:27,651 --> 00:07:33,468 You get much, much more information about the causes of an image by looking at the 93 00:07:33,468 --> 00:07:37,369 image than you do by looking at the label of the image. 94 00:07:37,369 --> 00:07:42,902 So in that situation, where there's a high bandwidth pathway from the world to 95 00:07:42,902 --> 00:07:48,719 the image, and a low bandwidth pathway from the world to the label, because the label 96 00:07:48,719 --> 00:07:54,153 typically contains very few bits. 
It makes much more sense to try and 97 00:07:54,153 --> 00:08:00,424 recover the label by first inverting the high bandwidth pathway to get back to the 98 00:08:00,424 --> 00:08:06,144 stuff in the world that caused the image. And then having recovered the stuff in 99 00:08:06,144 --> 00:08:10,850 the world that caused the image, to decide what label it would be given. 100 00:08:10,850 --> 00:08:17,020 So that's a much more plausible model of how we assign names to things in images. 101 00:08:17,020 --> 00:08:23,342 And that justifies having a pre-training phase where you try and go from the image 102 00:08:23,342 --> 00:08:29,513 to its underlying causes, followed by a discriminative phase where you try and go 103 00:08:29,513 --> 00:08:32,712 from those underlying causes to the label. 104 00:08:32,712 --> 00:08:38,806 And perhaps you slightly fine-tune the mapping from the image to the underlying 105 00:08:38,806 --> 00:08:39,340 causes.