Commit 510d0aa

Update on "Add prepend argument to nn.Module hooks"
cc ezyang gchanan [ghstack-poisoned]
2 parents ca3726c + 04b4aa6

8 files changed (+23, -87 lines)
.github/ci_commit_pins/vision.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-0d7807d59520289b2065b4db4a138b7fba2f61fd
+9c112935abe400222cca8f9fbc2d8386e0f25e80

aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 0 additions & 18 deletions
@@ -93,24 +93,6 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
                 { return mps::trunc_tensor(mpsGraph, inputTensor); });
 }
 
-TORCH_IMPL_FUNC(signbit_out_mps) (const Tensor& self, const Tensor& output) {
-  mps::unary_op(self, output, "signbit_out_mps",
-                ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
-                  MPSGraphTensor* output;
-                  // signbit is not implemented for int64 type.
-                  // workaround for `Function signbitOp_i64 was not found in the library`
-                  if ([inputTensor dataType] == MPSDataTypeInt64) {
-                    MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
-                    output = [mpsGraph lessThanWithPrimaryTensor:inputTensor
-                                                 secondaryTensor:zeroTensor
-                                                            name:nil];
-                  } else {
-                    output = [mpsGraph signbitWithTensor: inputTensor name: nil];
-                  }
-                  return mps::castMPSTensor(mpsGraph, output, ScalarType::Bool);
-                });
-}
-
 TORCH_IMPL_FUNC(sign_out_mps) (const Tensor& self, const Tensor& output) {
   mps::unary_op(self, output, "sign_out_mps",
                 ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
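
Note on the removed kernel: the int64 branch above worked around the missing signbitOp_i64 by comparing against zero, which is semantically equivalent for integer inputs (integers have no negative zero; only floating-point signbit must also report the sign of -0.0). A minimal Python sketch of that equivalence, for illustration only and not part of this commit:

import torch

# For integer dtypes, signbit(x) is True exactly where x < 0, so the
# less-than-zero comparison used in the removed MPS branch is a valid
# substitute for a dedicated int64 signbit kernel.
x = torch.tensor([-3, 0, 7], dtype=torch.int64)
assert torch.equal(torch.signbit(x), x < 0)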

aten/src/ATen/native/native_functions.yaml

Lines changed: 0 additions & 1 deletion
@@ -8533,7 +8533,6 @@
   dispatch:
     CPU: signbit_out
     CUDA: signbit_out
-    MPS: signbit_out_mps
     SparseCPU, SparseCUDA: signbit_sparse_out
     SparseCsrCPU, SparseCsrCUDA: signbit_sparse_csr_out
 
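With the MPS dispatch entry gone, this revision registers no native MPS kernel for signbit. A hedged sketch of a caller-side fallback, using a hypothetical helper name and not anything this commit adds:

import torch

def signbit_with_cpu_fallback(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: if the input's device has no native signbit
    # kernel, compute on CPU and move the boolean result back.
    try:
        return torch.signbit(x)
    except (NotImplementedError, RuntimeError):
        return torch.signbit(x.cpu()).to(x.device)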

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 15 additions & 17 deletions
@@ -234,18 +234,17 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_nested(
     return std::make_tuple(Tensor(), Tensor());
   }
 }
-namespace{
+
 
 /**
  * This function is used to calculate two pieces of metadata that are needed
  * for use with flash-attention and efficient_attention kernels. They are the
  * cumulative sequence_length over a batch of sequences and the maximum sequence
  * length.
  *
- * @return A tuple of cumulative sequence lengths and the maximum sequence length,
- * and the last element in the cumulative_sequence_lengths
+ * @return A tuple of cumulative sequence lengths and the maximum sequence length
  */
-std::tuple<Tensor, int64_t, int64_t> cumulative_and_max_seq_len(Tensor qkv) {
+std::tuple<Tensor, int64_t> cumulative_and_max_seq_len(Tensor qkv) {
   TORCH_CHECK(
       qkv.is_nested(),
       "QKV must be nested for flash cumulative_seq_len calculation.")

@@ -275,7 +274,7 @@ std::tuple<Tensor, int64_t, int64_t> cumulative_and_max_seq_len(Tensor qkv) {
   // Send to GPU, this is pretty light weight calc for normal batch size
   // but maybe this needs to be on gpu
   cumulative_seqlen = cumulative_seqlen.to(TensorOptions().device(at::kCUDA));
-  return std::tuple<Tensor, int64_t, int64_t>{cumulative_seqlen, max_seqlen, sum};
+  return std::tuple<Tensor, int64_t>{cumulative_seqlen, max_seqlen};
 }
 
 /**

@@ -338,7 +337,6 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
   return true;
 }
 
-} // namespace
 std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
     const Tensor& query,
     const Tensor& key,

@@ -356,19 +354,19 @@ std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
   Tensor k_t = key.transpose(1, 2);
   Tensor v_t = value.transpose(1, 2);
 
-  auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len(q_t);
-  auto cumulative_and_max_k_and_nnz_k = cumulative_and_max_seq_len(k_t);
+  auto cumulative_and_max_q = cumulative_and_max_seq_len(q_t);
+  auto cumulative_and_max_k = cumulative_and_max_seq_len(k_t);
 
   // K and V have to have the same Nnz, should probably torch_check
   // assume in order to not iterate over v
 
-  Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q_and_nnz_q);
-  Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k_and_nnz_k);
+  Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q);
+  Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k);
 
-  const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q);
+  const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q);
 
-  const int64_t Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q);
-  const int64_t Nnz_kv = std::get<2>(cumulative_and_max_k_and_nnz_k);
+  const int64_t Nnz_q = cumulative_sequence_length_q[-1].item<int64_t>();
+  const int64_t Nnz_kv = cumulative_sequence_length_k[-1].item<int64_t>();
 
   Tensor query_buffer_reshaped;
   Tensor key_buffer_reshaped;

@@ -462,15 +460,15 @@ Tensor flash_attention_helper(
   int64_t head_dim{query.size(-1)};
   int64_t num_heads{query.size(-2)};
 
-  auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len(query);
-  Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q_and_nnz_q);
-  int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q);
+  auto cumulative_and_max_q = cumulative_and_max_seq_len(query);
+  Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q);
+  int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q);
 
   TORCH_CHECK(
       key.is_same(key) && query.is_same(value),
       "Key and Value must be the same tensor");
 
-  int64_t Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q);
+  int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()};
 
   // For the packed case we need to set the output size for dim 2 to 1
   auto atten_size = get_nested_size_tensor(query).clone();
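
The dropped third tuple element (Nnz) is recoverable from the cumulative sequence lengths themselves: the last entry of the running total is the number of non-padded tokens across the batch, which is exactly what the call sites now read via cumulative_sequence_length[-1]. A short Python sketch of that relationship, with made-up sequence lengths, purely for illustration:

import torch

# Hypothetical per-sequence lengths for a batch of three nested sequences.
seq_lens = torch.tensor([4, 7, 2], dtype=torch.int32)

# Cumulative sequence lengths in the usual flash-attention layout:
# a leading zero followed by the running total -> [0, 4, 11, 13].
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0, dtype=torch.int32)])

max_seqlen = int(seq_lens.max())   # 7, the second tuple element
nnz = int(cu_seqlens[-1])          # 13 == seq_lens.sum(), the dropped element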

test/inductor/test_smoke.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

test/test_fake_tensor.py

Lines changed: 1 addition & 0 deletions
@@ -329,6 +329,7 @@ def fn(
         self.assertTrue(isinstance(ten, FakeTensor))
         self.assertEqual(ten.device.type, 'cuda')
 
+    @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_fallback_memory_prop(self):
         m = nn.Conv2d(16, 33, 3, stride=2, device="cuda", dtype=torch.half)

test/test_mps.py

Lines changed: 0 additions & 14 deletions
@@ -4175,20 +4175,6 @@ def helper(shape):
 
         helper((2, 8, 4, 5))
 
-    def test_signbit(self):
-        def helper(shape, dtype):
-            cpu_x = torch.randn(shape, device='cpu').to(dtype)
-            x = cpu_x.clone().to('mps')
-
-            signbit_result = torch.signbit(x)
-            signbit_result_cpu = torch.signbit(cpu_x)
-
-            self.assertEqual(signbit_result, signbit_result_cpu)
-
-        helper((2, 8, 4, 5), torch.int)
-        helper((2, 8, 4, 5), torch.float)
-        helper((2, 8, 4, 5), torch.int64)
-
     # Test neg
     def test_neg(self):
         def helper(shape):

torch/_C/__init__.pyi.in

Lines changed: 6 additions & 6 deletions
@@ -527,8 +527,8 @@ class Value:
 
 # Defined in torch/csrc/jit/ir/ir.h
 class Block:
-    def inputs(self) -> Iterator[Value]: ...
-    def outputs(self) -> Iterator[Value]: ...
+    def inputs(self) -> List[Value]: ...
+    def outputs(self) -> List[Value]: ...
     def nodes(self) -> Iterator[Node]: ...
     def paramNode(self) -> Node: ...
    def returnNode(self) -> Node: ...

@@ -542,11 +542,11 @@ class Node:
     def __getitem__(self, key: str) -> Any: ...
     def schema(self) -> str: ...
     def input(self) -> Value: ...
-    def inputs(self) -> Iterator[Value]: ...
+    def inputs(self) -> List[Value]: ...
     def inputsAt(self, idx: _int) -> Value: ...
     def inputsSize(self) -> _int: ...
     def output(self) -> Value: ...
-    def outputs(self) -> Iterator[Value]: ...
+    def outputs(self) -> List[Value]: ...
     def outputsAt(self, idx: _int) -> Value: ...
     def outputsSize(self) -> _int: ...
     def hasMultipleOutputs(self) -> _bool: ...

@@ -622,8 +622,8 @@ class Node:
 
 # Defined in torch/torch/csrc/jit/ir/ir.h
 class Graph:
-    def inputs(self) -> Iterator[Value]: ...
-    def outputs(self) -> Iterator[Value]: ...
+    def inputs(self) -> List[Value]: ...
+    def outputs(self) -> List[Value]: ...
     def nodes(self) -> Iterator[Node]: ...
     def param_node(self) -> Node: ...
     def return_node(self) -> Node: ...
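
The stub change from Iterator[Value] to List[Value] matters mostly to static type checkers: a List may be indexed, measured with len(), and traversed more than once, while an Iterator only supports a single forward pass. A generic illustration of the distinction (plain Python, not PyTorch API), assuming the annotation matches what the binding actually returns:

from typing import Iterator, List

def first_of_list(values: List[int]) -> int:
    # Lists support len() and indexing and can be iterated repeatedly.
    assert len(values) >= 1
    return values[0]

def first_of_iterator(values: Iterator[int]) -> int:
    # Iterators only allow next()/for loops; len(values) or values[0]
    # would be flagged by a type checker and fail at runtime for plain iterators.
    return next(values)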
