@@ -85,6 +85,30 @@ def _run_on_loss(self, net, param_init_net, param, grad=None):
         return output_blob


+class L1NormTrimmed(Regularizer):
+    """
+    The Trimmed Lasso: Sparsity and Robustness. https://arxiv.org/abs/1708.04527
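+
+    The penalty leaves the k largest-magnitude entries of the parameter
+    unpenalized and applies an L1 penalty to the rest:
+
+        reg_lambda * (sum_i |w_i| - sum of the k largest |w_i|)
+
+    A NumPy sketch of the same computation (illustrative values, not part
+    of the original change):
+
+    >>> import numpy as np
+    >>> w = np.array([0.5, -2.0, 0.1, 3.0, -0.2])
+    >>> reg_lambda, k = 0.1, 2
+    >>> abs_w = np.abs(w)
+    >>> float(round(reg_lambda * (abs_w.sum() - np.sort(abs_w)[-k:].sum()), 3))
+    0.08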
+    """
+    def __init__(self, reg_lambda, k):
+        super(L1NormTrimmed, self).__init__()
+        assert reg_lambda >= 0, "factor ahead of regularization should be 0 or positive"
+        assert isinstance(k, int), "k should be an integer, the number of largest entries left unpenalized"
+        assert k >= 1, "k should be at least 1"
+
+        self.reg_lambda = reg_lambda
+        self.k = k
+
+    def _run_on_loss(self, net, param_init_net, param, grad=None):
+        output_blob = net.NextScopedBlob(param + "_l1_trimmed_regularization")
+        abs_blob = net.Abs(param, [net.NextScopedBlob("abs")])
+        sum_abs = net.SumElements([abs_blob], [net.NextScopedBlob("sum_abs")], average=False)
+        # scope the TopK index outputs too, so that regularizing several
+        # params does not collide on shared blob names like 'id'/'flat_id'
+        topk, _, _ = net.TopK(
+            abs_blob,
+            [net.NextScopedBlob("topk"), net.NextScopedBlob("id"), net.NextScopedBlob("flat_id")],
+            k=self.k,
+        )
+        topk_sum = net.SumElements([topk], [net.NextScopedBlob("topk_sum")], average=False)
+        # total L1 mass minus the k largest magnitudes, scaled by reg_lambda
+        net.Sub([sum_abs, topk_sum], [output_blob])
+        net.Scale([output_blob], [output_blob], scale=self.reg_lambda)
+        return output_blob
+
+
 class L2Norm(Regularizer):
     def __init__(self, reg_lambda):
         super(L2Norm, self).__init__()
@@ -117,6 +141,31 @@ def _run_on_loss(self, net, param_init_net, param, grad=None):
         return output_blob


+class ElasticNetL1NormTrimmed(Regularizer):
+    """
+    Elastic net combining an L2 penalty with the trimmed L1 penalty of
+    L1NormTrimmed above:
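+
+        l1 * (sum_i |w_i| - sum of the k largest |w_i|) + l2 * LpNorm(w, p=2)
+
+    A NumPy sketch of the combined penalty (illustrative values, not part
+    of the original change; assumes LpNorm with p=2 yields the sum of the
+    squared entries, as used by L2Norm):
+
+    >>> import numpy as np
+    >>> w = np.array([0.5, -2.0, 0.1, 3.0, -0.2])
+    >>> l1, l2, k = 0.1, 0.01, 2
+    >>> abs_w = np.abs(w)
+    >>> trimmed_l1 = abs_w.sum() - np.sort(abs_w)[-k:].sum()
+    >>> float(round(l1 * trimmed_l1 + l2 * np.sum(w ** 2), 3))
+    0.213
+    """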
+    def __init__(self, l1, l2, k):
+        super(ElasticNetL1NormTrimmed, self).__init__()
+        self.l1 = l1
+        self.l2 = l2
+        self.k = k
+
+    def _run_on_loss(self, net, param_init_net, param, grad=None):
+        output_blob = net.NextScopedBlob(param + "_elastic_net_l1_trimmed_regularization")
+        # L2 part: LpNorm with p=2, scaled by l2 (as in L2Norm above)
+        l2_blob = net.NextScopedBlob(param + "_l2_blob")
+        net.LpNorm([param], [l2_blob], p=2)
+        net.Scale([l2_blob], [l2_blob], scale=self.l2)
+
+        # trimmed L1 part: total L1 mass minus the k largest magnitudes
+        l1_blob = net.NextScopedBlob(param + "_l1_blob")
+        abs_blob = net.Abs(param, [net.NextScopedBlob("abs")])
+        sum_abs = net.SumElements([abs_blob], [net.NextScopedBlob("sum_abs")], average=False)
+        topk, _, _ = net.TopK(
+            abs_blob,
+            [net.NextScopedBlob("topk"), net.NextScopedBlob("id"), net.NextScopedBlob("flat_id")],
+            k=self.k,
+        )
+        topk_sum = net.SumElements([topk], [net.NextScopedBlob("topk_sum")], average=False)
+        net.Sub([sum_abs, topk_sum], [l1_blob])
+        net.Scale([l1_blob], [l1_blob], scale=self.l1)
+
+        net.Add([l1_blob, l2_blob], [output_blob])
+        return output_blob
+
+
 class MaxNorm(Regularizer):
     def __init__(self, norm=1.0):
         super(MaxNorm, self).__init__()
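
For reference, a minimal sketch of wiring the new regularizer into a Caffe2 net (assumes this diff lands in caffe2/python/regularizer.py; the net names, "w", and the parameter values are illustrative, not from this change):

    from caffe2.python import core
    from caffe2.python.regularizer import L1NormTrimmed

    net = core.Net("train")
    param_init_net = core.Net("train_init")
    # a toy 16-dim parameter blob to regularize
    w = param_init_net.XavierFill([], "w", shape=[16])
    reg = L1NormTrimmed(reg_lambda=0.1, k=4)
    # returns a scalar blob holding the trimmed-L1 penalty for w,
    # to be added to the training loss
    penalty = reg._run_on_loss(net, param_init_net, w)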