Let's discuss how we can apply Markov chain Monte Carlo to train the Latent Dirichlet Allocation model, or LDA for short. We already discussed the LDA model in the previous week, on variational inference, and derived some complicated formulas to do the training. But now let's see how Markov chain Monte Carlo can be applied to the same problem, and what the pros and cons of this approach are.

Recall that LDA is a model for finding topics in a corpus of documents. The idea of LDA is that each document has a distribution over topics. For example, one document's distribution can be 80% about cats and 20% about dogs. Then each topic is a distribution over words. For example, the topic about cats can, in principle, generate any word, but it is much more likely to generate words like "cat" or "meow" than others.

So when we look at a particular document, for example "cat meowed on the dog", we can imagine that this document was generated as follows. First of all, we decided which distribution over topics to use; here it's 80 for cats and 20 for dogs. Then, for each word, we must decide on a topic. For our first word, we may flip a biased coin that lands on cats with probability 80%. Say for our first word we happened to draw the topic cats. Now we have to generate the first word itself: we know the topic, so we generate a word from it. In principle any word could come out, but we happened to draw the word "cat" as the first one. Then we repeat: we again sample a topic for the second word, and let's say it again happened to be cats, and we generated the word "meow". Finally, for our last word, let's say we happened to draw the topic dogs and generated the word "dog".

Notice that for this model the order of words doesn't matter; it's a kind of bag-of-words approach. But it is still flexible enough to discover meaningful topics in documents.
To put it a little more mathematically, each document has a distribution over topics, Theta_d: our document d has a distribution Theta_d, which is 80 and 20 in this case. Each word has a topic, which we don't know; it's a latent variable. For words one to N, Z_i is the topic of the i-th word in document d, and these topics are generated from the distribution Theta_d. Then each word itself is generated from the corresponding topic: if the first topic is cats, we generate word one from the distribution over words for the topic cats. This distribution is called Phi, so Phi_cats is the distribution over words for the topic cats. This is how we generate the words themselves, which we observe and use to train our model, that is, to find all the necessary latent variables and parameters.

In the variational inference week, we discussed how to use the expectation maximization algorithm to train this LDA model. On the E-step, we tried to find the posterior over the latent variables Z and Theta, but it was too hard to do analytically, so we decided to approximate: we looked for an approximation of the posterior in the family of factorized distributions, which is the idea of variational inference. Then, on the M-step, we maximized the expected value of the logarithm of the joint probability with respect to this approximate posterior. We spent quite a lot of time deriving the necessary analytical formulas for the E and M steps of this EM algorithm for LDA.

To summarize the model we have for LDA: we have the known data W, which is the words in our documents; we have the unknown parameters Phi, the distribution over words for each topic; and we have the unknown latent variables Z and Theta. For each word, Z gives us the topic of that word, and for each document, Theta gives us the distribution over topics for that document.
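Written out in the usual notation, the model for a document d with N_d words looks like this (a sketch; the Dirichlet prior parameter alpha is standard but is not named explicitly in the lecture):

```latex
\theta_d \sim \mathrm{Dirichlet}(\alpha), \qquad
z_i \mid \theta_d \sim \mathrm{Categorical}(\theta_d), \qquad
w_i \mid z_i \sim \mathrm{Categorical}(\phi_{z_i}),
\qquad i = 1, \dots, N_d .
```

In the fully Bayesian treatment discussed below, each topic Phi_k additionally gets its own Dirichlet prior (commonly with parameter beta) instead of being treated as a fixed parameter.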
It kind of makes sense that we decided Phi would be a parameter and Z and Theta would be latent variables, because the number of variables in Z and Theta depends on the amount of training data we have: for each document, we have to add one new Theta and a set of new Z's. So it makes sense to make them latent variables, so that our parameterization does not grow with the size of the training data. Phi, on the other hand, has a fixed size: the number of topics times the number of words in the dictionary. So we treated it as a parameter in our variational inference approach to LDA.

But let's see how we can deal with this model if we go fully Bayesian, that is, if we treat Phi as a latent variable itself. Now we don't have any parameters at all: we don't want to estimate any numbers, we want to marginalize out all the latent variables and make predictions. We will apply Markov chain Monte Carlo to this fully Bayesian model of LDA. We could also have applied Markov chain Monte Carlo inside the EM algorithm, but let's just use the fully Bayesian model as an illustration.

If we want to apply Markov chain Monte Carlo, we just need to sample all the latent variables from the posterior distribution, for example by using Gibbs sampling. With Gibbs sampling, we start with some initialization of all our latent variables, and then, on the first iteration, we sample the first coordinate of the latent variable Phi from its conditional distribution, conditioned on the initialized values of everything else. Then we do the same thing for the second coordinate, conditioning on the coordinate we just generated on the previous sub-step and on the values we still have from initialization. We repeat this for all coordinates to obtain the full Phi, and we do the same thing for Theta and Z. So for Theta, we generate its i-th coordinate from the conditional distribution of Theta_i given all the Phis we have already generated, all the Thetas we have already generated, and the initial values for the rest of the Thetas and for Z. And the same thing for Z. This is just the usual Gibbs sampling; it's really easy.
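As a rough illustration only (this is a hypothetical sketch, not the course's code), such a sampler could be written as follows. It resamples whole blocks (all of Phi, all of Theta, all of Z) rather than one scalar coordinate at a time as described above, which is still a valid Gibbs scheme here; the hyperparameter values alpha and beta are arbitrary choices for the sketch.

```python
import numpy as np

def gibbs_lda(docs, n_topics, n_words, alpha=0.1, beta=0.1, n_iters=500, seed=0):
    """Plain (uncollapsed) Gibbs sampling for LDA -- an illustrative sketch.

    docs: list of documents, each a list of word ids in range(n_words).
    Returns samples of (phi, theta) collected after a burn-in period.
    """
    rng = np.random.default_rng(seed)
    # Random initialization of the topic assignment z for every word position.
    z = [rng.integers(n_topics, size=len(doc)) for doc in docs]
    samples = []
    for it in range(n_iters):
        # Count statistics implied by the current z.
        n_kw = np.zeros((n_topics, n_words))      # topic-word counts
        n_dk = np.zeros((len(docs), n_topics))    # document-topic counts
        for d, doc in enumerate(docs):
            for i, w in enumerate(doc):
                n_kw[z[d][i], w] += 1
                n_dk[d, z[d][i]] += 1
        # Sample phi | z, w: each topic's word distribution is a Dirichlet draw.
        phi = np.array([rng.dirichlet(beta + n_kw[k]) for k in range(n_topics)])
        # Sample theta | z: one Dirichlet draw per document.
        theta = np.array([rng.dirichlet(alpha + n_dk[d]) for d in range(len(docs))])
        # Sample z | phi, theta, w: a categorical draw for every word position.
        for d, doc in enumerate(docs):
            for i, w in enumerate(doc):
                p = theta[d] * phi[:, w]
                z[d][i] = rng.choice(n_topics, p=p / p.sum())
        if it >= n_iters // 2:                    # keep post-burn-in samples
            samples.append((phi, theta))
    return samples
```

For example, calling gibbs_lda([[0, 1, 1], [2, 2, 0]], n_topics=2, n_words=3) would return post-burn-in samples of (phi, theta) for a toy corpus of two three-word documents.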
I mean, you're going to implement something like five lines of Python; nothing complicated here. We run it for a few hundred iterations of Gibbs sampling, wait until convergence, and then obtain samples from the desired posterior distribution. Once we have at least one sample from this posterior, we can just use it as an estimate of our parameters. So one sample from the posterior on Phi, for example, can be used directly as an estimate of Phi. Or you can average a few Phis if you want, or make predictions by averaging over a few Phis. You can use the samples to do almost anything with your model. It's really easy.

But the problem is that it will not be very efficient. Compared to variational inference, this approach will not be very fast. So let's see what we can do to make it faster in this particular case. The idea is to look closely at the model and try to marginalize something out. Basically, Gibbs sampling is trying to integrate something: we are substituting an expected value, which is an integral, with Monte Carlo samples. If we can integrate out some parts analytically and use Monte Carlo only for the rest, then the whole scheme will probably be more efficient.

Let's see if we can integrate out some parts of the model analytically. This is the model: we have a prior distribution on Theta, and Theta is just the probabilities of the topics for each document. For example, Theta for the first document may be 80 and 20: 80 for cats and 20 for dogs. The prior on Theta is a Dirichlet distribution, which is natural; it's the standard prior for this kind of latent variable, probabilities of a categorical variable. Then we have the conditional distribution of Z given Theta, which is just Theta itself: the topic of a word will be, for example, cats with probability Theta_cats. Let's see what we can do with this model. First of all, notice that these two distributions are conjugate, which basically means that we can find the posterior on Theta analytically.
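Concretely, conjugacy means the posterior is again a Dirichlet whose parameters are just the prior parameters plus the observed topic counts (written here with K topics and a prior parameter vector alpha, which the lecture does not name explicitly):

```latex
p(\theta_d \mid z_d) = \mathrm{Dirichlet}\bigl(\alpha_1 + n_{d1},\ \dots,\ \alpha_K + n_{dK}\bigr),
\qquad n_{dk} = \#\{\, i : z_i = k \text{ in document } d \,\}.
```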
And it's really easy: you can look up the formulas on Wikipedia, or you can derive them yourself; nothing complicated here, because the distributions are conjugate. Then we can try to marginalize out Theta, that is, to find the marginal distribution of Z. This looks complicated because it has an integral inside, but it turns out that since we know the posterior on Theta, we can represent it in this way. This formula comes just from the definition of conditional probability: if you multiply both sides by the probability of Theta given Z, you get the usual definition of conditional probability. And on the right-hand side you know all the terms. You know the likelihood, P of Z given Theta: it's just Theta itself. You know the prior: it's just a Dirichlet distribution. And you know the posterior, because the distributions are conjugate and you can compute it. So these things are easy.

Now we can do the same thing with Phi and W. Notice that the prior distribution on Phi is conjugate to the conditional distribution of W, and thus you can easily obtain the posterior on Phi. In the same way you can marginalize out Phi: you can find P of W given Z by using this formula, without any integration, because you know the posterior analytically. And finally, you can multiply P of Z by P of W given Z to obtain an unnormalized version of the posterior on Z given your data. Now all that's left is, for example, to sample from this posterior distribution. Notice that we don't know the posterior exactly, because we don't know the normalization constant and it's really expensive to compute. But that doesn't matter for Gibbs sampling: we can sample Z anyway. This gives us a very efficient scheme called collapsed Gibbs sampling, because we collapse some of the latent variables, marginalizing them out analytically. This approach allows you to sample only the Z's, and thus gives a more efficient procedure overall.
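For reference, once Theta and Phi are integrated out, the conditional used to resample a single topic assignment takes the standard count-based form found in the literature (written here with symmetric Dirichlet parameters alpha and beta and a vocabulary of V words; the superscript "not i" means the counts are computed with word i left out):

```latex
p(z_i = k \mid z_{\neg i}, w) \;\propto\;
\bigl(n_{dk}^{\neg i} + \alpha\bigr)\,
\frac{n_{k, w_i}^{\neg i} + \beta}{n_{k}^{\neg i} + V\beta},
```

where n_{dk} counts the words in document d currently assigned to topic k, n_{k,w_i} counts how often the word w_i is assigned to topic k across the corpus, and n_k is the total number of words assigned to topic k.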
You may ask: if I only have samples from the posterior on Z, how can I find the values of Phi, for example, which is something I care about? Well, because you know all the posterior distributions, you can obtain these kinds of things easily from the samples from the posterior on Z. The posterior of Phi given W, for example, which you may be interested in, can be computed by marginalizing Z out of the joint posterior of Phi and Z given W. This is an expected value with respect to the posterior on Z given W, so you can estimate it from samples: you sample a few Z's from this posterior distribution and then average to get an estimate of the posterior on Phi, and thus the actual parameters Phi, which give the probability of each word under each topic.

If you want to know more details about collapsed Gibbs sampling for LDA, look into the original paper that introduced it; it's really cool. Check out the additional reading material if you're interested.

The bottom line is that if you run collapsed Gibbs sampling for LDA, it is even a little more efficient than variational inference. On the plot, you can basically see time, or the number of operations performed, on the X-axis, and some measure of error on the Y-axis; the lower, the better. You can see that Gibbs sampling performs a little better than variational inference, and also better than something called EP, or Expectation Propagation, which we will not talk about in this course.

So we were able to obtain a really efficient scheme for the LDA model without all the trouble of variational inference. We didn't have to derive a large number of complicated formulas: we just applied the usual scheme of Gibbs sampling, which wasn't very efficient, and then marginalized out some parts, which was also really easy, and we finally got an efficient scheme for LDA. One more thing: generally, Markov chain Monte Carlo is not very well suited to mini-batches, and so to large datasets.
It has to process the whole dataset on each iteration. But sometimes it works anyway: people sometimes just run Markov chain Monte Carlo on mini-batches, and sometimes it fails miserably, but in this case the authors simply used mini-batches and it works quite nicely here.