@@ -8074,109 +8074,122 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
80748074
80758075
80768076foreach_unary_op_db: List[OpInfo] = [
8077- ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8078- ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8079- ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8080- ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8081- ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8082- ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8083- ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8084- ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8085- ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8086- ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8087- ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8088- ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8089- ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
8077+ ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8078+ ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8079+ ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8080+ ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8081+ ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8082+ ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8083+ ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8084+ ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8085+ ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8086+ ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8087+ ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8088+ ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
8089+ ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True ),
80908090
80918091 ForeachFuncInfo(
80928092 'neg',
80938093 dtypes=all_types_and_complex(),
80948094 dtypesIfCUDA=all_types_and_complex(),
80958095 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8096+ supports_autograd=True,
80968097 ),
80978098
80988099 ForeachFuncInfo(
80998100 'sqrt',
81008101 dtypes=floating_and_complex_types_and(torch.bfloat16),
81018102 dtypesIfCUDA=floating_and_complex_types_and(torch.half),
81028103 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8104+ supports_autograd=True,
81038105 ),
81048106
81058107 ForeachFuncInfo(
81068108 'ceil',
81078109 dtypes=all_types_and(torch.bfloat16),
81088110 dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
81098111 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8112+ supports_autograd=True,
81108113 ),
81118114
81128115 ForeachFuncInfo(
81138116 'erf',
81148117 dtypes=floating_types_and(torch.bfloat16),
81158118 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
81168119 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8120+ supports_autograd=True,
81178121 ),
81188122
81198123 ForeachFuncInfo(
81208124 'erfc',
81218125 dtypes=floating_types_and(torch.bfloat16),
81228126 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
81238127 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8128+ supports_autograd=True,
81248129 ),
81258130
81268131 ForeachFuncInfo(
81278132 'expm1',
81288133 dtypes=floating_types_and(torch.bfloat16),
81298134 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
81308135 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8136+ supports_autograd=True,
81318137 ),
81328138
81338139 ForeachFuncInfo(
81348140 'floor',
81358141 dtypes=all_types_and(torch.bfloat16),
81368142 dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
81378143 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8144+ supports_autograd=True,
81388145 ),
81398146
81408147 ForeachFuncInfo(
81418148 'log1p',
81428149 dtypes=floating_and_complex_types_and(torch.bfloat16),
81438150 dtypesIfCUDA=floating_and_complex_types_and(torch.half),
81448151 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8152+ supports_autograd=True,
81458153 ),
81468154
81478155 ForeachFuncInfo(
81488156 'round',
81498157 dtypes=all_types_and(torch.bfloat16),
81508158 dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
81518159 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8160+ supports_autograd=True,
81528161 ),
81538162
81548163 ForeachFuncInfo(
81558164 'frac',
81568165 dtypes=floating_types_and(torch.bfloat16),
81578166 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
81588167 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8168+ supports_autograd=True,
81598169 ),
81608170
81618171 ForeachFuncInfo(
81628172 'reciprocal',
81638173 dtypes=floating_types_and(torch.bfloat16),
81648174 dtypesIfCUDA=floating_types_and(torch.half),
81658175 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8176+ supports_autograd=True,
81668177 ),
81678178
81688179 ForeachFuncInfo(
81698180 'sigmoid',
81708181 dtypes=floating_types_and(torch.bfloat16),
81718182 dtypesIfCUDA=floating_types_and(torch.half),
81728183 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8184+ supports_autograd=True,
81738185 ),
81748186
81758187 ForeachFuncInfo(
81768188 'trunc',
81778189 dtypes=all_types_and(torch.bfloat16),
81788190 dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
81798191 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8192+ supports_autograd=True,
81808193 ),
81818194
81828195 ForeachFuncInfo(
@@ -8186,6 +8199,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
81868199 supports_forward_ad=True,
81878200 supports_fwgrad_bwgrad=True,
81888201 sample_inputs_func=foreach_inputs_sample_func(1, False, False),
8202+ supports_autograd=True,
81898203 ),
81908204]
81918205
0 commit comments