
Make Adam, AdamW work with nonzero-dim Tensor betas#149939

Closed
zeshengzong wants to merge 6 commits into pytorch:main from zeshengzong:opt/optim/sgd

Conversation

@zeshengzong
Contributor

@zeshengzong zeshengzong commented Mar 25, 2025

Fixes #147921

Changes

  • Convert tensor betas using _to_scalar (see the usage sketch after this list)
  • Change annotation of betas param
  • Change param type in docs
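
Usage sketch (illustrative only; assumes the merged change accepts 1-element Tensor betas of any dimensionality, per the linked issue; the names below are just for the example):

```python
import torch

# Sketch: 1-element Tensor betas, even nonzero-dim ones, should now be
# accepted and normalized internally via _to_scalar.
param = torch.nn.Parameter(torch.randn(2, 2))
opt = torch.optim.Adam(
    [param],
    lr=1e-3,
    betas=(torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
)
param.grad = torch.randn(2, 2)
opt.step()
```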

Test Result

pytest -s test/test_optim.py -k test_tensor_lr -vv

(test result screenshots attached)

@pytorch-bot

pytorch-bot bot commented Mar 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149939

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 38a5a6d with merge base f11ac80:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zeshengzong zeshengzong marked this pull request as ready for review March 25, 2025 07:36
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 26, 2025
@albanD albanD removed their request for review April 9, 2025 16:58
"betas": (torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
"amsgrad": False,
"capturable": True,
"fused": True,
Contributor Author

Hi @janeyx99, it looks like fused=True fails with complex types because the backend doesn't support them yet. Is there a way to control the input dtype in the test and skip complex types? Thanks!

It fails here:

```python
if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)):
    raise RuntimeError(
        "`fused=True` requires all the params to be floating point Tensors of "
        f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}"
    )
```

pytest -s test/test_optim.py -k test_complex_2d_AdamW_cuda_complex64

______________________________________________________________________________________________ TestOptimRenewedCUDA.test_complex_2d_AdamW_cuda_complex64 _______________________________________________________________________________________________
Traceback (most recent call last):
  File "/home/zong/miniconda3/envs/torch/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
    yield
  File "/home/zong/miniconda3/envs/torch/lib/python3.12/unittest/case.py", line 634, in run
    self._callTestMethod(testMethod)
  File "/home/zong/miniconda3/envs/torch/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/home/zong/code/pytorch/torch/testing/_internal/common_utils.py", line 3154, in wrapper
    method(*args, **kwargs)
  File "/home/zong/code/pytorch/torch/testing/_internal/common_utils.py", line 3154, in wrapper
    method(*args, **kwargs)
  File "/home/zong/code/pytorch/torch/testing/_internal/common_device_type.py", line 446, in instantiated_test
    raise rte
  File "/home/zong/code/pytorch/torch/testing/_internal/common_device_type.py", line 426, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/testing/_internal/common_utils.py", line 1612, in wrapper
    fn(*args, **kwargs)
  File "/home/zong/code/pytorch/torch/testing/_internal/common_optimizers.py", line 226, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/testing/_internal/common_device_type.py", line 1215, in dep_fn
    return fn(slf, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/test/test_optim.py", line 664, in test_complex_2d
    optim1.step()
  File "/home/zong/code/pytorch/torch/optim/optimizer.py", line 507, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/optim/optimizer.py", line 80, in _use_grad
    ret = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/optim/adam.py", line 237, in step
    has_complex = self._init_group(
                  ^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/optim/adam.py", line 163, in _init_group
    _device_dtype_check_for_fused(p)
  File "/home/zong/code/pytorch/torch/optim/optimizer.py", line 197, in _device_dtype_check_for_fused
    raise RuntimeError(
RuntimeError: `fused=True` requires all the params to be floating point Tensors of supported devices: ['mps', 'cuda', 'xpu', 'hpu', 'cpu', 'privateuseone'] but torch.complex64 and cuda

To execute this test, run the following from the base repo dir:
    python test/test_optim.py TestOptimRenewedCUDA.test_complex_2d_AdamW_cuda_complex64

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
=============================================================================================================== short test summary info ================================================================================================================
FAILED [2.1007s] test/test_optim.py::TestOptimRenewedCUDA::test_complex_2d_AdamW_cuda_complex64 - RuntimeError: `fused=True` requires all the params to be floating point Tensors of supported devices: ['mps', 'cuda', 'xpu', 'hpu', 'cpu', 'privateuseone'] but torch.complex64 and cuda

Contributor

Ah sorry, we shouldn't specify "fused" in these OptimInfos after all--that was my bad. (Please delete this config)

What I care about is that the fused optimizer is tested with these betas. Can you confirm such a configuration with Tensor betas and fused is tested in one of the existing fused tests?

Contributor

@janeyx99 janeyx99 left a comment

Can you show some proof that tensor betas + fused AdamW has run in a test? That is the last thing blocking this PR for me

@zeshengzong
Contributor Author

> Can you show some proof that tensor betas + fused AdamW has run in a test? That is the last thing blocking this PR for me

Got it, let me check it, thanks!

@zeshengzong
Contributor Author

The compiler seems unhappy about the change; I need to resolve that first. 😿

@zeshengzong
Contributor Author

The test method invokes _get_optim_inputs_including_global_cliquey_kwargs, which enumerates fused and the other global flags to generate test cases combining each OptimizerInput with fused.

pytorch/test/test_optim.py

Lines 320 to 330 in c1b7dbc

```python
def test_tensor_lr(self, device, dtype, optim_info, num_dim):
    optim_cls = optim_info.optim_cls
    lr_devices = [device]
    if _get_device_type(device) != "cpu":
        lr_devices.append("cpu")
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
```

```python
for flag in supported_impls:
    new_kwargs = deepcopy(base_kwargs)
    new_kwargs[flag] = True
    all_optim_inputs.append(
        OptimizerInput(
            params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}"
        )
    )
```

So there is a test case for Tensor betas with fused=True once an OptimizerInput with Tensor betas is added. Thanks!
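
Roughly the kind of entry meant here, as a sketch only (kwargs adapted from the config reviewed above, without fused, since the helper adds that flag itself):

```python
import torch
from torch.testing._internal.common_optimizers import OptimizerInput

# Sketch: an OptimizerInput carrying nonzero-dim Tensor betas. The
# _get_optim_inputs_including_global_cliquey_kwargs helper then fans each
# input out over the global flags (e.g. fused=True), as the loop above shows.
tensor_betas_input = OptimizerInput(
    params=None,
    kwargs={
        "betas": (torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
        "capturable": True,
    },
    desc="Tensor betas",
)
```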

(screenshot attached)

```diff
@@ -84,6 +84,7 @@ def __init__(
         )
         if betas[1].numel() != 1:
             raise ValueError("Tensor betas[1] must be 1-element")
+        betas = tuple(map(_to_scalar, betas))
```
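
A quick illustration of why the numel() guard above admits nonzero-dim betas while the added _to_scalar conversion is still needed:

```python
import torch

beta2 = torch.tensor([[0.99]])  # nonzero-dim, but still a single element
print(beta2.numel())            # 1 -> passes the "must be 1-element" check
print(beta2.dim())              # 2 -> not a scalar yet, hence _to_scalar
```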
Contributor Author

A conversion is needed here because the tests create Adam via the constructor and run it with torch.compile. When a compiled version of Adam's step method is invoked, it skips the step implementation below and goes straight to compiled Triton code (also skipping the _to_scalar calls in step and the methods it calls).

So the betas were not converted into 0-dim tensors before the Triton kernel was invoked, which produced the errors below.


ERROR: test_correctness_Adam_use_closure_False_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA.test_correctness_Adam_use_closure_False_cuda_float32)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/zong/code/pytorch/torch/_inductor/runtime/static_cuda_launcher.py", line 228, in run
    _StaticCudaLauncher._launch_kernel(
RuntimeError: CUDA driver error: invalid argument

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/zong/code/pytorch/torch/testing/_internal/common_utils.py", line 3141, in wrapper
    method(*args, **kwargs)
  File "/home/zong/code/pytorch/torch/testing/_internal/common_utils.py", line 3141, in wrapper
    method(*args, **kwargs)
  File "/home/zong/code/pytorch/torch/testing/_internal/common_device_type.py", line 446, in instantiated_test
    raise rte
  File "/home/zong/code/pytorch/torch/testing/_internal/common_device_type.py", line 426, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/testing/_internal/common_optimizers.py", line 226, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/testing/_internal/common_device_type.py", line 1215, in dep_fn
    return fn(slf, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/testing/_internal/common_device_type.py", line 1215, in dep_fn
    return fn(slf, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/test/inductor/test_compiled_optimizers.py", line 664, in test_correctness
    fn()
  File "/home/zong/code/pytorch/torch/_dynamo/eval_frame.py", line 699, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/test/inductor/test_compiled_optimizers.py", line 654, in fn
    opt_compiled.step()
  File "/home/zong/code/pytorch/torch/optim/lr_scheduler.py", line 131, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/optim/optimizer.py", line 506, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/optim/optimizer.py", line 80, in _use_grad
    ret = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/optim/adam.py", line 213, in step
    @_use_grad_for_differentiable
  File "/home/zong/code/pytorch/torch/_dynamo/eval_frame.py", line 893, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_functorch/aot_autograd.py", line 1231, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 338, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/home/zong/code/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 696, in inner_fn
    outs = compiled_fn(args)
           ^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 502, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_inductor/output_code.py", line 582, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_inductor/utils.py", line 2661, in run
    return model(new_inputs)
           ^^^^^^^^^^^^^^^^^
  File "/tmp/tmpl9cibdeh/jv/cjvt2dlgiwbtybi4lnwxvx47mx4g5jzyhquzxefp32ghvc3chpi4.py", line 590, in call
    triton_poi_fused_add_addcdiv_div_lerp_maximum_mul_neg_pow_reciprocal_rsub_sqrt_0.run(arg9_1, arg22_1, arg23_1, buf0, arg8_1, arg11_1, arg10_1, arg20_1.item(), arg21_1, arg0_1, arg13_1, arg25_1, arg12_1, arg15_1, arg14_1, arg2_1, arg8_1, arg9_1, arg11_1, arg0_1, arg12_1, arg13_1, arg15_1, arg2_1, 100, stream=stream0)
  File "/home/zong/code/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 1069, in run
    return launcher(
           ^^^^^^^^^
  File "<string>", line 5, in launcher
  File "/home/zong/code/pytorch/torch/_inductor/runtime/static_cuda_launcher.py", line 240, in run
    raise RuntimeError(
RuntimeError: Failed to launch kernel triton_poi_fused_add_addcdiv_div_lerp_maximum_mul_neg_pow_reciprocal_rsub_sqrt_0 with args 

Contributor

Is there a way to get around this by changing the compile tests? How come it won't properly trace the updated step?

Contributor Author

Probably we can't change the tests: they pull their test data from optim_db, so changing that would affect all optimizer-related tests, and users may also construct the optimizer this way.

```python
@skipCUDAIf(not has_triton(), "torch.compile with cuda requires triton")
@skipXPUIf(not has_triton(), "torch.compile with xpu requires triton")
@optims(optim_db, dtypes=[torch.float32])
@parametrize("use_closure", [True, False])
def test_correctness(self, device, dtype, optim_info, use_closure):
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
```

I'm also curious about why _to_scalar is not included in the compiled code; I need more time to figure that out. :D
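
For reference, a rough stand-in for what the _to_scalar conversion is assumed to do in this context (an illustration only, not the actual helper in torch/optim/optimizer.py):

```python
import torch

def to_scalar_sketch(x):
    # Hypothetical: collapse any 1-element Tensor so downstream code never
    # sees a [[0.9]]-shaped hyperparameter (the shape the Triton kernel
    # launch above rejected).
    if isinstance(x, torch.Tensor) and x.numel() == 1:
        return x.squeeze()  # e.g. shape (1, 1) -> 0-dim Tensor
    return x

betas = (torch.tensor([[0.9]]), torch.tensor([[[0.999]]]))
print(tuple(b.dim() for b in map(to_scalar_sketch, betas)))  # (0, 0)
```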

Contributor

@mlazos do you remember if there's a place we'd need to update when _init_group updates?

@github-actions
Contributor

github-actions bot commented Aug 1, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Aug 1, 2025
@fffrog fffrog removed the Stale label Aug 4, 2025
@github-actions
Contributor

github-actions bot commented Oct 3, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 3, 2025
@janeyx99 janeyx99 removed the Stale label Oct 6, 2025
@janeyx99
Contributor

janeyx99 commented Oct 6, 2025

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 6, 2025
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased opt/optim/sgd onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout opt/optim/sgd && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Oct 6, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@janeyx99
Contributor

janeyx99 commented Oct 6, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 6, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@janeyx99 janeyx99 added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label Oct 6, 2025
@janeyx99
Contributor

janeyx99 commented Oct 6, 2025

@pytorchbot merge

Suppressing BC linter as it is simply type widening.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@janeyx99
Contributor

janeyx99 commented Oct 6, 2025

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: BC Lint / bc_linter, pull / linux-jammy-py3.10-clang12 / test (dynamo_wrapped, 3, 3, linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Fixes pytorch#147921

## Changes

- Convert tensor `betas` using `_to_scalar`
- Change annotation of `betas` param
- Change param type in docs

## Test Result

```bash
pytest -s test/test_optim.py -k test_tensor_lr -vv
```

![image](https://github.com/user-attachments/assets/312ee045-1e8b-4789-aa6e-ba63e6df7e81)

![image](https://github.com/user-attachments/assets/7e6ec274-645b-46b9-b1a6-2b340a685203)

Pull Request resolved: pytorch#149939
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>

Labels

ciflow/trunk: Trigger trunk jobs on your pull request
Merged
open source
release notes: optim
suppress-bc-linter: Suppresses the failures of API backward-compatibility linter (Lint/bc_linter)
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Adam doesn't work with nonzero-dim Tensor betas

6 participants