forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathunet.py
More file actions
73 lines (61 loc) · 2.73 KB
/
unet.py
File metadata and controls
73 lines (61 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ..features.cuda_graph import CUDAGraph
class DSUNet(CUDAGraph, torch.nn.Module):
    """Inference wrapper around a UNet that can run it through a captured CUDA graph.

    The wrapped ``unet`` is frozen (``requires_grad_(False)``) and converted to
    ``channels_last`` memory format. When ``enable_cuda_graph`` is true, the first
    call to :meth:`forward` captures a CUDA graph of the UNet forward pass and
    every call (including the first) is served by replaying that graph;
    otherwise the UNet is called eagerly.

    NOTE(review): ``enable_cuda_graph`` is read back as ``self.enable_cuda_graph``
    in :meth:`forward`; it is presumably stored by the ``CUDAGraph`` base class,
    which is not visible from this file -- confirm there.
    """

    def __init__(self, unet, enable_cuda_graph=True):
        """Wrap ``unet`` and mirror the attributes callers read from it.

        Args:
            unet: The UNet module to wrap. It must expose ``in_channels``,
                ``device``, ``dtype`` and ``config``.
            enable_cuda_graph: Forwarded to the base-class initializer; selects
                between graph replay and eager execution in :meth:`forward`.
        """
        super().__init__(enable_cuda_graph=enable_cuda_graph)
        self.unet = unet
        # SD pipeline accesses this attribute
        self.in_channels = unet.in_channels
        self.device = self.unet.device
        self.dtype = self.unet.dtype
        self.config = self.unet.config
        self.fwd_count = 0
        self.unet.requires_grad_(requires_grad=False)
        self.unet.to(memory_format=torch.channels_last)
        # Flipped to True by _create_cuda_graph once capture has succeeded.
        self.cuda_graph_created = False

    def _graph_replay(self, *inputs, **kwargs):
        """Copy new tensor arguments into the captured buffers and replay the graph.

        Only tensor arguments are refreshed; non-tensor arguments are not
        consulted on replay, so the values seen at capture time stay in effect.

        Returns:
            ``self.static_output`` -- the same object on every call.
        """
        for idx, value in enumerate(inputs):
            if torch.is_tensor(value):
                self.static_inputs[idx].copy_(value)
        for key, value in kwargs.items():
            if torch.is_tensor(value):
                self.static_kwargs[key].copy_(value)
        self._cuda_graphs.replay()
        return self.static_output

    def forward(self, *inputs, **kwargs):
        """Run the UNet, via CUDA graph replay when enabled, eagerly otherwise.

        Arguments are forwarded unchanged to :meth:`_forward` (eager path) or to
        graph capture/replay (graph path).
        """
        if not self.enable_cuda_graph:
            return self._forward(*inputs, **kwargs)
        if not self.cuda_graph_created:
            # First call: capture the graph, then fall through to a replay so
            # the first call returns through the same path as later ones.
            self._create_cuda_graph(*inputs, **kwargs)
        return self._graph_replay(*inputs, **kwargs)

    def _create_cuda_graph(self, *inputs, **kwargs):
        """Warm up the UNet on a side stream, then capture its forward pass.

        The argument objects passed here are kept as ``static_inputs`` /
        ``static_kwargs``; :meth:`_graph_replay` later copies fresh tensor
        values into them in place.
        """
        # warmup to create the workspace and cublas handle
        cuda_stream = torch.cuda.Stream()
        cuda_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(cuda_stream):
            for _ in range(3):
                self._forward(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._cuda_graphs = torch.cuda.CUDAGraph()
        self.static_inputs = inputs
        self.static_kwargs = kwargs

        with torch.cuda.graph(self._cuda_graphs):
            self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)

        self.cuda_graph_created = True

    def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True, cross_attention_kwargs=None):
        """Call the wrapped UNet eagerly.

        Args:
            sample: First positional argument for the UNet.
            timestamp: Second positional argument for the UNet.
            encoder_hidden_states: Third positional argument for the UNet.
            return_dict: Forwarded to the UNet as the ``return_dict`` keyword.
            cross_attention_kwargs: Forwarded only when truthy.

        Returns:
            Whatever ``self.unet(...)`` returns.
        """
        # return_dict is passed by keyword: in diffusers' UNet2DConditionModel.forward
        # the fourth positional parameter is class_labels, so a positional
        # return_dict would bind to the wrong parameter.
        if cross_attention_kwargs:
            return self.unet(sample,
                             timestamp,
                             encoder_hidden_states,
                             cross_attention_kwargs=cross_attention_kwargs,
                             return_dict=return_dict)
        return self.unet(sample, timestamp, encoder_hidden_states, return_dict=return_dict)