1 00:00:00,000 --> 00:00:04,612 [MUSIC] 2 00:00:04,612 --> 00:00:08,595 All right, let's see the mean field approximation. 3 00:00:08,595 --> 00:00:10,274 Here is how it works. 4 00:00:10,274 --> 00:00:13,996 We select a family of distributions Q, as in variational inference. 5 00:00:13,996 --> 00:00:17,459 So we select a family of distributions in the following way. 6 00:00:17,459 --> 00:00:21,888 It is the set of all distributions that are factorized 7 00:00:21,888 --> 00:00:26,426 over the dimensions of the latent variables. 8 00:00:26,426 --> 00:00:33,808 So, this will be a product of qi, the distribution of the ith latent variable. 9 00:00:33,808 --> 00:00:37,921 And then we do an optimization by minimizing the KL divergence between 10 00:00:37,921 --> 00:00:41,208 the variational distribution and the full posterior. 11 00:00:41,208 --> 00:00:48,617 So here is an example, here we have a distribution over two random variables. 12 00:00:48,617 --> 00:00:51,557 The true posterior is p* and 13 00:00:51,557 --> 00:00:56,598 the factorized one would be q1(z1) q2(z2). 14 00:00:56,598 --> 00:00:59,932 Imagine that the true posterior is a normal 15 00:00:59,932 --> 00:01:04,564 distribution with some covariance matrix sigma and mean 0. 16 00:01:04,564 --> 00:01:09,397 Then the factorized distribution would also be normal, but 17 00:01:09,397 --> 00:01:13,101 it would have a diagonal covariance matrix. 18 00:01:13,101 --> 00:01:17,908 If this was our true posterior, then the approximation would be like this. 19 00:01:17,908 --> 00:01:22,395 So note here that we don't have the off-diagonal elements of 20 00:01:22,395 --> 00:01:24,990 the covariance matrix in the red Gaussian. 21 00:01:26,890 --> 00:01:29,731 So optimization works as follows. 22 00:01:29,731 --> 00:01:34,562 We'll start with some distribution q, and we'll perform 23 00:01:34,562 --> 00:01:39,696 a coordinate descent with respect to different elements of q.
24 00:01:39,696 --> 00:01:44,283 And at the first step we'll optimize with respect to q1, we'll get a new 25 00:01:44,283 --> 00:01:49,679 distribution, then on the second step we optimize with respect to q2, and so on. 26 00:01:49,679 --> 00:01:54,610 This will be a loop and we'll optimize until convergence. 27 00:01:54,610 --> 00:01:58,786 Let's now derive the formulas for the update at each step. 28 00:02:01,731 --> 00:02:04,785 Our plan for now is to derive the formulas for 29 00:02:04,785 --> 00:02:08,195 the updates of the mean field approximation. 30 00:02:08,195 --> 00:02:12,078 So, here's our step of coordinate descent. 31 00:02:12,078 --> 00:02:15,962 We are trying to minimize the KL divergence with respect to one 32 00:02:15,962 --> 00:02:17,070 dimension, qk. 33 00:02:17,070 --> 00:02:21,240 Let's write down the value of the KL divergence. 34 00:02:21,240 --> 00:02:25,904 So this would be an integral of 35 00:02:25,904 --> 00:02:31,499 the product over all dimensions, 36 00:02:31,499 --> 00:02:37,280 i from 1 to d, qi times the logarithm 37 00:02:37,280 --> 00:02:41,959 of the ratio, so logarithm. 38 00:02:41,959 --> 00:02:47,134 And again, the product for 39 00:02:47,134 --> 00:02:52,986 all dimensions, qi over p*. 40 00:02:52,986 --> 00:02:58,201 And we integrate it over all z. 41 00:02:58,201 --> 00:03:02,420 We can take out the products from the logarithm and 42 00:03:02,420 --> 00:03:05,544 it will be the sum of the logarithms. 43 00:03:05,544 --> 00:03:11,596 And we can also take the denominator as a separate integral. 44 00:03:11,596 --> 00:03:16,310 Now we'll have, 45 00:03:16,310 --> 00:03:21,965 sum for i from 1 to d, 46 00:03:21,965 --> 00:03:28,888 integral of the product. 47 00:03:28,888 --> 00:03:33,139 Again, let's write it down as j, so 48 00:03:33,139 --> 00:03:37,401 that we'll have separate indices.
49 00:03:37,401 --> 00:03:41,405 So j from 1 to d, 50 00:03:41,405 --> 00:03:46,643 qj logarithm of qi, 51 00:03:46,643 --> 00:03:52,806 dz, minus the integral 52 00:03:52,806 --> 00:03:58,367 of the denominator. 53 00:03:58,367 --> 00:04:03,594 So again, it would be integral, 54 00:04:03,594 --> 00:04:09,762 product over j from 1 to d, qj log p*. 55 00:04:09,762 --> 00:04:13,740 [SOUND] Easy. 56 00:04:16,106 --> 00:04:20,941 All right, we're interested only in the component qk. 57 00:04:20,941 --> 00:04:24,517 So let's separate this summation into two terms. 58 00:04:24,517 --> 00:04:28,702 One with i = k and all others. 59 00:04:28,702 --> 00:04:30,129 So this would be equal to 60 00:04:34,623 --> 00:04:39,246 the integral of the product over j 61 00:04:39,246 --> 00:04:43,868 from 1 to d 62 00:04:43,868 --> 00:04:48,497 qj log qk. 63 00:04:48,497 --> 00:04:52,778 This is the term that we're particularly interested in. 64 00:04:52,778 --> 00:04:58,189 dz, plus the sum over all other components, 65 00:04:58,189 --> 00:05:05,716 so it would be the sum over i not equal to k of the same integral. 66 00:05:05,716 --> 00:05:12,290 So it would be the product 67 00:05:12,290 --> 00:05:16,672 over j from 1 to d, 68 00:05:16,672 --> 00:05:20,428 qj log qi dz and 69 00:05:20,428 --> 00:05:25,123 minus this term, 70 00:05:25,123 --> 00:05:32,010 product for j from 1 to d, 71 00:05:32,010 --> 00:05:36,408 qj log of p* dz. 72 00:05:36,408 --> 00:05:42,699 So let's find out which terms are constants with respect to qk. 73 00:05:42,699 --> 00:05:45,591 Let's start by rewriting the first term. 74 00:05:45,591 --> 00:05:50,727 So it actually equals the integral 75 00:05:50,727 --> 00:05:55,364 of qk times the logarithm of qk and 76 00:05:55,364 --> 00:06:03,329 also we can separately integrate over all other dimensions.
77 00:06:03,329 --> 00:06:10,199 So here is an integral of the product 78 00:06:10,199 --> 00:06:15,695 over j not equal to k, qj dz, 79 00:06:15,695 --> 00:06:20,733 let me write it down like, 80 00:06:20,733 --> 00:06:23,719 not equal to k. 81 00:06:23,719 --> 00:06:30,595 So those are all variables except for the kth, so k here. 82 00:06:30,595 --> 00:06:36,719 And finally, we integrate over zk. 83 00:06:36,719 --> 00:06:40,647 So this term actually equals 1, since q is a distribution and 84 00:06:40,647 --> 00:06:44,161 the integral of a distribution equals 1. 85 00:06:46,767 --> 00:06:52,480 So this equals 1, and 86 00:06:52,480 --> 00:06:58,738 finally we get the integral 87 00:06:58,738 --> 00:07:04,735 of qk logarithm of qk dzk. 88 00:07:04,735 --> 00:07:08,079 All right, so this term surely depends on qk. 89 00:07:08,079 --> 00:07:12,777 However, these terms will have a similar form, so 90 00:07:12,777 --> 00:07:19,350 those would be equal to the integral of qj times the logarithm of qj dzj. 91 00:07:19,350 --> 00:07:24,062 And since these don't depend on qk, 92 00:07:24,062 --> 00:07:29,392 we can say that these are just constants. 93 00:07:29,392 --> 00:07:35,210 Those are just constants, so constant. 94 00:07:35,210 --> 00:07:41,836 All right, so here's our derivation again, let's continue deriving the formula. 95 00:07:47,774 --> 00:07:53,893 This would be equal to an integral 96 00:07:53,893 --> 00:07:58,919 of qk times the logarithm of 97 00:07:58,919 --> 00:08:04,169 qk dzk minus this integral. 98 00:08:04,169 --> 00:08:09,539 Let me separate the dimension number k out from the integral. 99 00:08:09,539 --> 00:08:14,236 So this would be an integral of qk, here I will 100 00:08:14,236 --> 00:08:19,589 write the integral over all other dimensions.
101 00:08:19,589 --> 00:08:25,315 So product over j not equal to k, 102 00:08:25,315 --> 00:08:30,449 qj times logarithm of p* d, so 103 00:08:30,449 --> 00:08:37,558 here we integrate over all j not equal to k, 104 00:08:37,558 --> 00:08:45,060 so we write it down again as z not equal to k, 105 00:08:45,060 --> 00:08:50,806 and we have to close the brackets. 106 00:08:50,806 --> 00:08:57,621 So here dzk and I guess that's it. 107 00:08:57,621 --> 00:09:01,521 All right, let's now put it into one integral. 108 00:09:01,521 --> 00:09:07,193 So we can group these as an integral of qk times 109 00:09:07,193 --> 00:09:13,328 the difference between this term and this term. 110 00:09:13,328 --> 00:09:18,520 So, we'll have an integral, 111 00:09:18,520 --> 00:09:24,712 qk times the following difference, 112 00:09:24,712 --> 00:09:30,314 logarithm of qk minus integral, 113 00:09:32,759 --> 00:09:36,964 the product for j not equal to k, 114 00:09:36,964 --> 00:09:42,178 qj log p*, and dz not equal to k, and 115 00:09:42,178 --> 00:09:47,225 finally, we close this bracket and 116 00:09:47,225 --> 00:09:53,133 integrate over the last dimension, zk. 117 00:09:53,133 --> 00:09:56,183 All right, what is this term? 118 00:09:56,183 --> 00:10:01,173 This term actually equals the expectation of 119 00:10:01,173 --> 00:10:06,921 the logarithm of p* over all dimensions except for the kth. 120 00:10:06,921 --> 00:10:11,191 So we have this term equal to 121 00:10:11,191 --> 00:10:15,777 the expected value over q except for 122 00:10:15,777 --> 00:10:21,166 the kth, so let me write it down like this. 123 00:10:21,166 --> 00:10:26,303 The logarithm of p*. 124 00:10:26,303 --> 00:10:32,606 So let's write it down as some function of zk. 125 00:10:32,606 --> 00:10:36,849 For example, h(zk). 126 00:10:36,849 --> 00:10:39,683 So actually, it turns out 127 00:10:39,683 --> 00:10:44,148 we can modify it a little bit and get a distribution.
128 00:10:44,148 --> 00:10:48,759 For this, we'll just have to renormalize the exponent of it. 129 00:10:48,759 --> 00:10:54,491 So, let's have some new distribution that equals 130 00:10:54,491 --> 00:11:00,502 the exponent of h(zk), and we have to renormalize it. 131 00:11:00,502 --> 00:11:05,566 So we have to integrate over zk prime, 132 00:11:05,566 --> 00:11:10,846 e to the h(zk prime), dzk prime. 133 00:11:10,846 --> 00:11:16,165 This will be our new distribution, let's call it t. 134 00:11:25,238 --> 00:11:29,372 Actually, I should write here + const. 135 00:11:33,737 --> 00:11:35,586 And here we also have some constant. 136 00:11:37,813 --> 00:11:42,948 Now, we can notice that it actually equals 137 00:11:42,948 --> 00:11:49,040 the KL divergence between distributions qk and t. 138 00:11:49,040 --> 00:11:54,038 So this will be an integral of qk, 139 00:11:54,038 --> 00:11:59,928 the logarithm of the ratio between qk and 140 00:11:59,928 --> 00:12:05,639 t, dzk, plus some new constant since 141 00:12:05,639 --> 00:12:11,368 we don't have here the denominator. 142 00:12:11,368 --> 00:12:14,717 So, plus some new constant. 143 00:12:14,717 --> 00:12:17,717 And we try to minimize it. 144 00:12:17,717 --> 00:12:23,473 So, again this is a KL divergence 145 00:12:23,473 --> 00:12:28,004 between qk and t + const. 146 00:12:28,004 --> 00:12:32,770 So we just have to minimize the KL divergence. 147 00:12:32,770 --> 00:12:36,593 And so we already know what the answer is, 148 00:12:36,593 --> 00:12:42,616 we have to take qk equal to the distribution t as written here. 149 00:12:42,616 --> 00:12:46,689 So, qk equals t, 150 00:12:46,689 --> 00:12:54,443 the better way to write it down is as follows. 151 00:12:54,443 --> 00:12:59,078 So actually what we say is that those two terms are equal, 152 00:12:59,078 --> 00:13:02,399 we say that this term equals this term.
153 00:13:02,399 --> 00:13:10,159 And so we have log qk = h(zk), 154 00:13:10,159 --> 00:13:15,594 which is the expectation 155 00:13:15,594 --> 00:13:21,805 over all variables except for 156 00:13:21,805 --> 00:13:27,756 the kth, so expectation over 157 00:13:27,756 --> 00:13:33,994 q without k of log p* + const. 158 00:13:33,994 --> 00:13:36,292 And so this is our final formula. 159 00:13:43,087 --> 00:13:47,569 We'll see how to use it in the next video. 160 00:13:47,569 --> 00:13:57,569 [MUSIC]
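As a rough illustration of the final formula, here is a minimal sketch for the two-variable Gaussian example from the start of the video. It assumes the true posterior p* is a normal distribution with mean mu and covariance sigma. In that case the expectation over q without k of log p* is quadratic in zk, so each factor qk is itself Gaussian, with precision equal to the kth diagonal entry of the inverse covariance matrix and a mean that depends on the current mean of the other factor. The function name `mean_field_gaussian_2d` and its parameters are made up for this sketch and do not come from the lecture.

```python
def mean_field_gaussian_2d(mu, sigma, m_init=(1.0, -1.0), n_sweeps=50):
    """Coordinate-descent mean-field updates for a 2-D Gaussian target.

    Hypothetical helper for illustration only.
    mu:    (mu1, mu2), mean of the assumed true posterior p*.
    sigma: ((s11, s12), (s21, s22)), covariance of p*.
    Returns the factor means (m1, m2) and factor variances (v1, v2).
    """
    (s11, s12), (s21, s22) = sigma
    det = s11 * s22 - s12 * s21
    # Precision matrix Lambda = inverse of the 2x2 covariance matrix.
    l11, l12, l22 = s22 / det, -s12 / det, s11 / det

    m1, m2 = m_init
    for _ in range(n_sweeps):
        # log q1(z1) = E_{q2}[log p*(z1, z2)] + const is quadratic in z1,
        # which gives a Gaussian with precision l11 and this mean.
        m1 = mu[0] - (l12 / l11) * (m2 - mu[1])
        # Same update for q2, using the freshly updated mean of q1.
        m2 = mu[1] - (l12 / l22) * (m1 - mu[0])

    # The factor variances are fixed by the diagonal of the precision matrix.
    return (m1, m2), (1.0 / l11, 1.0 / l22)


means, variances = mean_field_gaussian_2d(
    mu=(0.0, 0.0), sigma=((1.0, 0.8), (0.8, 1.0))
)
print(means, variances)
```

With this strongly correlated target, the factor means move toward the true mean 0, while each factor variance comes out smaller than the true marginal variance of 1. This matches the picture of the factorized Gaussian, which has no off-diagonal covariance elements.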