[MUSIC]

Up to this point, we've derived our coordinate descent algorithm for least squares, and we've discussed what the variant is for lasso. But we didn't derive explicitly why our coordinate descent algorithm for lasso has that soft thresholding form. For completeness, we're going to include the derivation, but as an optional video. And I really want to emphasize that this video is optional. The material is quite advanced; it's material that's typically taught at the level of a Ph.D. or grad-level course. So, watch it at your own risk. We think it's still very interesting if you want to see what's under the hood in the lasso solver. But that's our little word of warning before we delve into this content, because again this is an example where what we're gonna present in this video is much more technical than what we're assuming for the rest of this specialization.

Okay. So, let's get into optimizing our lasso objective, which we've written here, and we're doing it again one coordinate at a time to derive our coordinate descent algorithm. In this derivation we're gonna use unnormalized features, so we're not doing the normalization; this is just our typical h that's appearing here. And we're doing that because that's the most general case we can derive, and the normalized case follows directly from this derivation.
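For reference, here is that objective written out, in the notation of this module (h_j are the features, N the number of observations, lambda the L1 tuning parameter, and the intercept is handled with the same convention as in the earlier videos):

\[
  \text{cost}(\mathbf{w})
  = \sum_{i=1}^{N}\Bigl(y_i - \sum_{j} w_j\, h_j(\mathbf{x}_i)\Bigr)^{2}
  \;+\; \lambda \sum_{j} \lvert w_j \rvert
  = \text{RSS}(\mathbf{w}) + \lambda \lVert \mathbf{w}\rVert_1 ,
\]

and coordinate descent minimizes this over one coordinate w_j at a time, holding all the other coordinates fixed.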
Okay, so to start, let's look at the partial of our residual sum of squares term with respect to w_j. We did this derivation when we were looking at deriving coordinate descent for least squares, but we did it for normalized features, so let's do it now for unnormalized features. Remember, what we're gonna do is just pull out this w_j term, so we're gonna get a sum over i equals 1 to N of h_j(x_i), and I'm gonna skip a step, assuming that you remember what we did in that previous video, where I separate out the sum over k not equal to j of w_k h_k(x_i) from the term where we get w_j h_j(x_i). So, when we do this multiplication, we're gonna get minus 2 times the sum over i equals 1 to N of h_j(x_i) times (y_i minus the sum over k not equal to j of w_k h_k(x_i)). And here we're gonna get plus 2; again the w_j comes out, and we get the sum over i equals 1 to N of h_j(x_i) squared.

When we were talking about normalized features, we said the sum of the h_j squared was equal to one. But now we're looking at unnormalized features, so let's define some terms. Again, we're gonna call this first term rho_j, even though in this case we're looking at unnormalized features. And here we have to define what we're calling this normalizer, and we're gonna call it z_j. Okay, so the result, going back to my blue colors: we get minus 2 rho_j just like before, but now we get plus 2 w_j z_j. So, that completes our residual sum of squares term.
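Collecting the pieces just described, the partial of the residual sum of squares with respect to w_j, with rho_j and z_j defined as on the slide, is:

\[
  \frac{\partial}{\partial w_j}\,\text{RSS}(\mathbf{w})
  = -2\sum_{i=1}^{N} h_j(\mathbf{x}_i)\Bigl(y_i - \sum_{k\neq j} w_k\, h_k(\mathbf{x}_i)\Bigr)
    + 2\,w_j\sum_{i=1}^{N} h_j(\mathbf{x}_i)^2
  = -2\rho_j + 2\,w_j z_j ,
\]
\[
  \text{where}\quad
  \rho_j = \sum_{i=1}^{N} h_j(\mathbf{x}_i)\Bigl(y_i - \sum_{k\neq j} w_k\, h_k(\mathbf{x}_i)\Bigr),
  \qquad
  z_j = \sum_{i=1}^{N} h_j(\mathbf{x}_i)^2 .
\]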
Now, let's turn to our L1 penalty, and this is where things become more complicated. In particular, what's the partial with respect to w_j of the absolute value of w_j? Remember, all the other terms we're thinking of as held fixed, so we're just looking at this w_j component. Well, this is where these derivatives of absolute values become problematic. Here we have the derivative equal to minus 1, here we have the derivative equal to plus 1, and here we have this problem point. And we mentioned that instead of thinking about gradients, or these partials, we think about subgradients.

Okay, so now let's talk about these subgradients and formalize this a bit. But first let's mention that gradients can be thought of as lower bounding convex functions. So, if I look at my convex function at some points a and b, and I take the gradient at my point a, that's this tangent plane; this is the gradient of g at a. Then what I have is that g at b, and just to be clear, this is g(a) and this is g(b), so g(b) is greater than or equal to g(a) plus the gradient of g at a times (b minus a). And importantly, gradients are unique at a point if the function is differentiable at that point.

Subgradients generalize this idea of gradients to non-differentiable points, and a subgradient is going to be any plane that lower bounds the function. So, a subgradient is really a set: it's the set of all the planes that lower bound the function. We're gonna say that v is in our set, which we denote by this curly d of g at a point a, our subgradient of g at a, if we have that g(b) is greater than or equal to g(a) plus, in place of this gradient here, we write v, and again we have (b minus a). So, here this was the one plane that lower bounded our function, and here this is one of the planes that lower bounds our function. So, this is the definition of a subgradient.

Let's look at it in the context of this absolute value function. What are all the planes that lower bound this absolute value function? Well, all the planes that I'm gonna draw lower bound this absolute value function: there's this plane, there's this plane, there's this plane with positive slope. And I could fill this space here with all these planes. And what are the slopes of these planes? Here we know that the slope is equal to minus 1, and here the slope is equal to plus 1. And we see that any line with a slope in the range of minus 1 to 1 is gonna lower bound this absolute value function. So, for the absolute value of x, at the point x equals 0, we have that v is in the interval from minus 1 to 1; that interval represents our subgradient of the absolute value of x at 0.
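In symbols, the definition and the absolute-value example just described are:

\[
  v \in \partial g(a)
  \quad\Longleftrightarrow\quad
  g(b) \;\ge\; g(a) + v\,(b-a) \quad \text{for all } b,
\]
\[
  \partial\lvert x\rvert =
  \begin{cases}
    \{-1\} & x < 0,\\
    [-1,\,1] & x = 0,\\
    \{+1\} & x > 0 .
  \end{cases}
\]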
So now that we've had our detour into the world of subgradients, instead of the partial of this L1 term, we can talk about the subgradient of this L1 term, where here we get lambda times the subgradient. We already know that the subgradient of the absolute value at zero is the range minus 1 to 1. So, lambda times the subgradient of the absolute value is going to be the interval minus lambda to lambda when w_j is equal to 0. Remember, a subgradient is defined at that problem point w_j equals 0, but in the case where the derivative, or that partial, exists, it's just lambda times the derivative: in the case of w_j less than 0 we had minus 1, so this is lambda times minus 1; here we have lambda times the interval minus 1 to 1; and in the case where we're in this positive half plane, we get lambda times plus 1. Okay, so this is our complete lambda times the subgradient of our L1 term with respect to w_j.

Now that we have the partial of our residual sum of squares term and the subgradient of our L1 term, we can put it all together and get the subgradient of our entire lasso cost with respect to w_j. Here, this part is from our residual sum of squares, whereas this part is from the L1 penalty, or really lambda times the L1 penalty. And when we put these things together, we get three different cases, because of the three cases for the L1 term. So, in the first case we get 2 z_j w_j minus 2 rho_j minus lambda. I won't read all of the different cases, but this now is the subgradient of our entire lasso cost.

Okay, well, let's not get lost in the weeds here. Let's pop up a level and remember that before, we would take the gradient and set it equal to 0, or in the coordinate descent algorithm, we talked about taking the partial with respect to one coordinate and then setting that equal to zero to get the update for that dimension. Here, instead of this partial derivative, we're taking the subgradient, and we have to set it equal to zero, or more precisely, require that zero is contained in it, to get the update for this specific jth coordinate.

So, let's do that: we've taken our subgradient, and we're setting it equal to 0. Again, we have three different cases we need to consider. In Case 1, where w_j is less than 0, let's solve for w hat j. You get w hat j equals (2 rho_j plus lambda) divided by 2 z_j, which I'm just gonna rewrite as (rho_j plus lambda over 2) divided by z_j; I've multiplied the top and bottom by one half here. But to be in this case, to have w hat j less than zero, we need a constraint on rho_j. Remember, rho_j is that correlation term, and it's something we can compute because it's a function of all the other variables except w_j. So, if rho_j is less than minus lambda over 2, then we know that w hat j will be less than 0 according to this formula.
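Written out in full, the subgradient of the lasso cost with respect to w_j, the residual sum of squares partial plus lambda times the subgradient of |w_j|, has the three cases:

\[
  \partial_{w_j}\,\text{cost}(\mathbf{w}) =
  \begin{cases}
    2 z_j w_j - 2\rho_j - \lambda & w_j < 0,\\[2pt]
    [\,-2\rho_j - \lambda,\; -2\rho_j + \lambda\,] & w_j = 0,\\[2pt]
    2 z_j w_j - 2\rho_j + \lambda & w_j > 0,
  \end{cases}
\]

and setting the first case to zero is exactly what gives the w hat j = (rho_j + lambda/2)/z_j solution above.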
Okay, then we get to the second case, which is w_j equals 0. In that case, we've effectively already solved for w hat j; there's only one possible solution, w hat j equals 0. But in order for that to be the optimum, we know that this subgradient, when w_j equals 0, has to contain zero; otherwise we would never get that this case, where w_j is equal to zero, is an optimum. So, we need this range to contain zero, so that w hat j equals 0 is an optimum of our objective. And for that to be true, we need minus 2 rho_j plus lambda to be greater than 0, so the upper end of the interval is greater than 0, which is equivalent to saying that rho_j is less than lambda over 2; and we need the bottom end of the interval to be less than zero, so minus 2 rho_j minus lambda less than zero, which is equivalent to saying rho_j is greater than minus lambda over 2. And if we put these together, what this is saying is that rho_j is less than lambda over 2 and greater than minus lambda over 2. And actually we could put the equal signs here, so let's just do that: it has to be less than or equal to, less than or equal to, less than or equal to.

Okay, and our final case, let's just quickly work through it: we get w hat j equals (rho_j minus lambda over 2) divided by z_j, and in order to have w hat j be greater than 0, we need rho_j to be greater than lambda over 2.

So, let me just talk through this in the other direction now that we've done the derivation: if rho_j is less than minus lambda over 2, then we'll set w hat j as in the first case; if rho_j is in this middle interval, we'll set w hat j equal to 0; and if rho_j is greater than lambda over 2, we're gonna set w hat j as in the third case. Okay, so this slide just summarizes what I just said: this is the solution to our one-dimensional optimization of the lasso objective.
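For concreteness, here is a minimal Python sketch of that three-case soft thresholding update inside a coordinate descent loop for unnormalized features. The names (lasso_coordinate_descent, H, y, lam, num_iters) are illustrative rather than the course's own code, and if an intercept feature is included you would typically leave it unpenalized, a refinement this sketch omits.

import numpy as np

def lasso_coordinate_descent(H, y, lam, num_iters=100, w_init=None):
    """Coordinate descent sketch for RSS(w) + lam * ||w||_1, unnormalized features.

    H   : (N, D) array whose columns are the features h_j(x_i)
    y   : (N,) array of targets
    lam : L1 penalty strength (lambda in the derivation)
    """
    N, D = H.shape
    w = np.zeros(D) if w_init is None else np.array(w_init, dtype=float)
    z = (H ** 2).sum(axis=0)                 # z_j = sum_i h_j(x_i)^2 (normalizer)

    for _ in range(num_iters):
        for j in range(D):
            # rho_j: correlation of feature j with the residual that
            # excludes feature j's own contribution.
            residual_without_j = y - H.dot(w) + w[j] * H[:, j]
            rho_j = H[:, j].dot(residual_without_j)

            # Soft thresholding: the three cases from the derivation.
            if rho_j < -lam / 2.0:
                w[j] = (rho_j + lam / 2.0) / z[j]
            elif rho_j > lam / 2.0:
                w[j] = (rho_j - lam / 2.0) / z[j]
            else:
                w[j] = 0.0
    return w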
Now, let's talk about this more general form of the soft thresholding rule for lasso in the case of our unnormalized features. Remember, for our normalized features there wasn't this z_j here, and what we ended up with for our least squares solution, when lambda was equal to 0, was just this line, w hat j equals rho_j. But now what does our least squares line look like? Well again, we can just set lambda equal to 0, and we see that this least squares line, w hat least squares, is equal to rho_j over z_j. Remember, z_j is that normalizer, the sum of the squares of the jth feature over all observations, so that number will be positive, and it's typically gonna be larger than one, potentially much larger than one. So, relative to a 45-degree line with slope 1, I'm saying that this line is typically shrunk more this way in the case of unnormalized features. And then, when I look at my lasso solution, w hat lasso, in this case, again in the range minus lambda over 2 to lambda over 2, I get this thresholding of the coefficients exactly to zero, relative to my least squares solution; and outside that range, the difference between the least squares solution and my lasso solution is that my coefficients are each shrunk by an amount lambda over 2 z_j.

Okay, but remember that rho_j here, as compared to when we talked about normalized features, was defined differently: it was defined in terms of our unnormalized features. So, for the same value of lambda that you would use with normalized features, you're getting a different relationship here, a different range where things are set to zero.

So, in summary, we've derived this coordinate descent algorithm for lasso in the case of unnormalized features. In particular, the key insight we had was that instead of just taking the partial of our objective with respect to w_j, we had to take the subgradient of our objective with respect to w_j, and that's what leads to these three different cases, because the derivative itself is defined for every value of w_j except for this one critical point, w_j equals 0. But we also gained a lot of insight into how this soft thresholding gives us the sparsity in our lasso solutions.

[MUSIC]