
Conversation

@milesial
Contributor

@milesial milesial commented Dec 26, 2022

Adds _foreach_clamp_min and _foreach_clamp_max as binary ops, with scalar, scalarlist and tensorlist support.

Timing example for _foreach_clamp_min_ on an RTX 3070 Ti across a list of tensors with varying count and item size (times are in microseconds (us)):

CUDA:

[------------------ (tensors, scalar) -------------------]
                                   |  for loop  |  foreach
      10 tensors of size 4         |     29.0   |     10.2
      100 tensors of size 4        |    234.4   |     18.3
      1000 tensors of size 4       |   2194.1   |    113.5
      10000 tensors of size 4      |  21745.6   |   1144.5
      10 tensors of size 16        |     29.5   |     12.0
      100 tensors of size 16       |    256.9   |     19.9
      1000 tensors of size 16      |   2499.7   |    123.6
      10000 tensors of size 16     |  25022.2   |   1295.6
      10 tensors of size 256       |     32.8   |     11.2
      100 tensors of size 256      |    258.8   |     19.7
      1000 tensors of size 256     |   2509.2   |    123.7
      10000 tensors of size 256    |  25016.2   |   1295.4
      10 tensors of size 65536     |     32.9   |     18.7
      100 tensors of size 65536    |    327.1   |    150.3
      1000 tensors of size 65536   |   3051.3   |   1388.0
      10000 tensors of size 65536  |  30476.9   |  14021.5

[------------------ (tensors, tensors) ------------------]
                                   |  for loop  |  foreach
      10 tensors of size 4         |     26.8   |     17.3
      100 tensors of size 4        |    206.8   |     90.5
      1000 tensors of size 4       |   1993.0   |    828.9
      10000 tensors of size 4      |  19851.0   |   9063.3
      10 tensors of size 16        |     34.7   |     20.0
      100 tensors of size 16       |    232.2   |    102.1
      1000 tensors of size 16      |   2220.9   |    977.3
      10000 tensors of size 16     |  22644.5   |  10361.4
      10 tensors of size 256       |     30.5   |     19.7
      100 tensors of size 256      |    231.6   |    102.4
      1000 tensors of size 256     |   2251.9   |    978.7
      10000 tensors of size 256    |  22680.3   |  10405.8
      10 tensors of size 65536     |     30.6   |     34.4
      100 tensors of size 65536    |    315.1   |    223.6
      1000 tensors of size 65536   |   3252.1   |   2114.4
      10000 tensors of size 65536  |  30578.0   |  22826.3

CPU:

[------------------- (tensors, scalar) -------------------]
                                   |  for loop  |  foreach 
      10 tensors of size 4         |      13.0  |       9.6
      100 tensors of size 4        |      62.4  |      31.6
      1000 tensors of size 4       |     562.2  |     245.6
      10000 tensors of size 4      |    5552.2  |    2517.7
      10 tensors of size 16        |      14.9  |      11.3
      100 tensors of size 16       |      74.1  |      36.9
      1000 tensors of size 16      |     663.7  |     285.5
      10000 tensors of size 16     |    6765.2  |    2947.5
      10 tensors of size 256       |      15.2  |      11.8
      100 tensors of size 256      |      76.0  |      37.7
      1000 tensors of size 256     |     728.8  |     323.9
      10000 tensors of size 256    |    7274.4  |    3800.3
      10 tensors of size 65536     |     105.6  |     124.5
      100 tensors of size 65536    |     982.8  |     939.7
      1000 tensors of size 65536   |   14993.1  |   14579.2
      10000 tensors of size 65536  |  163091.0  |  151555.8

[------------------- (tensors, tensors) ------------------]
                                   |  for loop  |  foreach 
      10 tensors of size 4         |      11.8  |      10.5
      100 tensors of size 4        |      53.1  |      38.2
      1000 tensors of size 4       |     465.1  |     316.1
      10000 tensors of size 4      |    4616.9  |    3625.9
      10 tensors of size 16        |      13.5  |      12.3
      100 tensors of size 16       |      63.0  |      46.5
      1000 tensors of size 16      |     560.1  |     359.9
      10000 tensors of size 16     |    5586.8  |    3765.9
      10 tensors of size 256       |      15.2  |      13.7
      100 tensors of size 256      |      64.4  |      48.3
      1000 tensors of size 256     |     653.7  |     410.0
      10000 tensors of size 256    |    5916.6  |    3901.3
      10 tensors of size 65536     |     109.1  |     106.8
      100 tensors of size 65536    |    1128.9  |    1105.0
      1000 tensors of size 65536   |   16245.0  |   15950.8
      10000 tensors of size 65536  |  171111.3  |  163540.2

Example use:

tensors = [torch.randn(16, device='cuda') for _ in range(10)]

# out-of-place: scalar, scalar list, and tensor list variants
out = torch._foreach_clamp_min(tensors, 0.1)
out = torch._foreach_clamp_min(tensors, [0.1] * len(tensors))
out = torch._foreach_clamp_min(tensors, tensors)

# in-place variants
torch._foreach_clamp_min_(tensors, 0.1)
torch._foreach_clamp_min_(tensors, [0.1] * len(tensors))
torch._foreach_clamp_min_(tensors, tensors)

Does not support complex types.
Changes the existing foreach_minimum/maximum to use this new implementation.
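
For reference, the NaN-propagating clamp semantics these ops implement can be sketched in plain Python (a minimal model of the elementwise rule, not the actual fused kernel; `clamp_min` and `foreach_clamp_min` are hypothetical helper names used only for illustration):

```python
import math

def clamp_min(a, b):
    # Elementwise rule: a NaN input propagates unchanged;
    # otherwise the result is max(a, b).
    return a if math.isnan(a) or a > b else b

def foreach_clamp_min(values, bounds):
    # Reference semantics over flat Python lists; the real
    # _foreach_clamp_min fuses this loop into a few kernels.
    return [clamp_min(a, b) for a, b in zip(values, bounds)]
```

The tensor-list and scalar-list overloads follow the same pairwise rule, with the scalar overload broadcasting a single bound to every tensor.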

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @Guobing-Chen @chunyuan-w @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

@pytorch-bot pytorch-bot bot added the release notes: foreach_frontend release notes category label Dec 26, 2022
@linux-foundation-easycla

linux-foundation-easycla bot commented Dec 26, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: milesial (083daebed4253a59098c5dda579785ba0f659f7b)

@pytorch-bot

pytorch-bot bot commented Dec 26, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91384

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit f7b450f:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@milesial milesial marked this pull request as draft December 26, 2022 16:24
@vadimkantorov
Contributor

@cpuhrsch With the proliferation of foreach methods, is it worth adding TensorList (as accepted by foreach methods) as some sort of NestedTensor or a companion first-order TensorList structure?

@milesial
Contributor Author

I'm not happy with the duplication of this piece:

template <typename T>
struct clamp_min {
    __device__ T operator()(const T& a, const T& b) const { return _isnan(a) or a > b ? a : b; }
};

template <typename T>
struct clamp_max {
    __device__ T operator()(const T& a, const T& b) const { return _isnan(a) or a < b ? a : b; }
};

Where can I put it so that I can include it in the three .cu files?

@milesial milesial marked this pull request as ready for review December 27, 2022 00:22
@ngimel
Collaborator

ngimel commented Dec 27, 2022

Where can I put it so that I include it in the three .cu files?

Just create a new header file if none of the existing headers seem appropriate? The aten/native/cuda folder contains a few headers, so one more is not a problem.

@bdhirsh bdhirsh added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Dec 27, 2022
@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Dec 31, 2022
@zou3519 zou3519 removed their request for review January 3, 2023 16:04
@milesial
Contributor Author

milesial commented Jan 4, 2023

@ngimel ready for final review.

In the process of fixing tests I added bool support to the regular clamp forward on CUDA, and bool+float16 support on CPU. I also expanded the nan/inf test to all foreach binary ops.

The Windows and multiprocessing test failures are unrelated.

@milesial milesial requested review from ngimel and removed request for mruberry January 5, 2023 18:01
@milesial milesial requested a review from ngimel January 6, 2023 08:39
Collaborator

@ngimel ngimel left a comment


This looks great @milesial, let's see what CI says.

@ngimel
Collaborator

ngimel commented Jan 6, 2023

Test failure looks unrelated

@ngimel
Collaborator

ngimel commented Jan 6, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 6, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: The following mandatory check(s) failed (Rule superuser):

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

@milesial
Contributor Author

milesial commented Jan 9, 2023

@ngimel merge failed :/

@ngimel
Collaborator

ngimel commented Jan 9, 2023

@pytorchbot merge -f "test failure flaky"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@DanilBaibak
Contributor

@pytorchbot revert -m "Break internal build" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

Reverting PR 91384 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit 9d20d6d5ec5c0ac5ff00e4967f480f07ba0bb2bf returned non-zero exit code 1

Auto-merging aten/src/ATen/native/ForeachOpsKernels.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/native/ForeachOpsKernels.cpp
Auto-merging aten/src/ATen/native/native_functions.yaml
Auto-merging test/test_foreach.py
CONFLICT (content): Merge conflict in test/test_foreach.py
Auto-merging torch/testing/_internal/common_methods_invocations.py
CONFLICT (content): Merge conflict in torch/testing/_internal/common_methods_invocations.py
error: could not revert 9d20d6d5ec... Foreach clamp_min clamp_max (#91384)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
Details for Dev Infra team Raised by workflow job



Labels

  • ciflow/trunk — Trigger trunk jobs on your pull request
  • Merged
  • module: cpu — CPU specific problem (e.g., perf, algorithm)
  • module: inductor
  • open source
  • release notes: foreach_frontend — release notes category
  • triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
