forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvae.py
More file actions
150 lines (124 loc) · 5.88 KB
/
vae.py
File metadata and controls
150 lines (124 loc) · 5.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from ..features.cuda_graph import CUDAGraph
class DSVAE(CUDAGraph, torch.nn.Module):
    """Inference wrapper around a VAE with optional CUDA-graph capture and replay.

    The wrapped ``vae`` is frozen via ``requires_grad_(False)``. When
    ``enable_cuda_graph`` is true, the first call to :meth:`decode`,
    :meth:`encode` or :meth:`forward` captures a separate CUDA graph using that
    call's arguments. Later calls copy their tensor arguments into the captured
    buffers and replay the graph.

    Caveats of the replay path, which follow from how the graphs are captured below:

    * The argument objects passed on the capturing call are kept as the static
      input buffers, so later calls overwrite those tensors in place.
    * Every replay returns the same static output object. Its contents are
      overwritten by the next replay of that graph.
    * Only tensor arguments are refreshed on replay. Non-tensor arguments keep
      the values from the capturing call.
    """

    # Eager iterations run on a side stream before capture. They create the
    # workspace and cuBLAS handle outside the graph.
    _CUDA_GRAPH_WARMUP_ITERS = 3

    def __init__(self, vae, enable_cuda_graph=True):
        """Wrap ``vae`` and freeze its parameters.

        Args:
            vae: the model to wrap. Its ``config``, ``device`` and ``dtype`` are
                mirrored onto this wrapper.
            enable_cuda_graph: passed to the base initializer. This class reads it
                back as ``self.enable_cuda_graph`` to choose between graph replay
                and eager calls.
        """
        super().__init__(enable_cuda_graph=enable_cuda_graph)
        self.vae = vae
        self.config = vae.config
        self.device = self.vae.device
        self.dtype = self.vae.dtype
        self.vae.requires_grad_(requires_grad=False)
        # One "captured" flag per graph: decode, encode and forward are captured independently.
        self.decoder_cuda_graph_created = False
        self.encoder_cuda_graph_created = False
        self.all_cuda_graph_created = False

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _copy_into_static_buffers(static_inputs, static_kwargs, inputs, kwargs):
        """Copy a call's tensor arguments into the buffers a graph was captured with.

        Non-tensor arguments are skipped. Raises ``IndexError``/``KeyError`` if the
        call passes a positional/keyword argument the capturing call did not have.
        """
        for idx, value in enumerate(inputs):
            if torch.is_tensor(value):
                static_inputs[idx].copy_(value)
        for key, value in kwargs.items():
            if torch.is_tensor(value):
                static_kwargs[key].copy_(value)

    def _capture_cuda_graph(self, fn, inputs, kwargs):
        """Warm up ``fn(*inputs, **kwargs)`` on a side stream, then capture it.

        Returns:
            ``(graph, output)``: the new ``torch.cuda.CUDAGraph`` and the output
            object produced during capture. Replaying the graph rewrites the
            tensors inside ``output``.
        """
        # warmup to create the workspace and cublas handle
        warmup_stream = torch.cuda.Stream()
        warmup_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(warmup_stream):
            for _ in range(self._CUDA_GRAPH_WARMUP_ITERS):
                fn(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(warmup_stream)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            output = fn(*inputs, **kwargs)
        return graph, output

    # ------------------------------------------------------------------ decoder

    def _graph_replay_decoder(self, *inputs, **kwargs):
        """Refresh the decoder graph's static inputs, replay it, and return its static output."""
        self._copy_into_static_buffers(self.static_decoder_inputs, self.static_decoder_kwargs, inputs, kwargs)
        self._decoder_cuda_graph.replay()
        return self.static_decoder_output

    def _decode(self, x, return_dict=True):
        """Eager decode: ``self.vae.decode(x, return_dict=return_dict)``."""
        return self.vae.decode(x, return_dict=return_dict)

    def _create_cuda_graph_decoder(self, *inputs, **kwargs):
        """Capture the decoder graph. ``inputs``/``kwargs`` become its static buffers."""
        self.static_decoder_inputs = inputs
        self.static_decoder_kwargs = kwargs
        self._decoder_cuda_graph, self.static_decoder_output = self._capture_cuda_graph(
            self._decode, inputs, kwargs)
        self.decoder_cuda_graph_created = True

    def decode(self, *inputs, **kwargs):
        """Decode via ``self.vae.decode``, going through the decoder CUDA graph when enabled."""
        if not self.enable_cuda_graph:
            return self._decode(*inputs, **kwargs)
        if not self.decoder_cuda_graph_created:
            self._create_cuda_graph_decoder(*inputs, **kwargs)
        return self._graph_replay_decoder(*inputs, **kwargs)

    # ------------------------------------------------------------------ encoder

    def _graph_replay_encoder(self, *inputs, **kwargs):
        """Refresh the encoder graph's static inputs, replay it, and return its static output."""
        self._copy_into_static_buffers(self.static_encoder_inputs, self.static_encoder_kwargs, inputs, kwargs)
        self._encoder_cuda_graph.replay()
        return self.static_encoder_output

    def _encode(self, x, return_dict=True):
        """Eager encode: ``self.vae.encode(x, return_dict=return_dict)``."""
        return self.vae.encode(x, return_dict=return_dict)

    def _create_cuda_graph_encoder(self, *inputs, **kwargs):
        """Capture the encoder graph. ``inputs``/``kwargs`` become its static buffers."""
        self.static_encoder_inputs = inputs
        self.static_encoder_kwargs = kwargs
        self._encoder_cuda_graph, self.static_encoder_output = self._capture_cuda_graph(
            self._encode, inputs, kwargs)
        self.encoder_cuda_graph_created = True

    def encode(self, *inputs, **kwargs):
        """Encode via ``self.vae.encode``, going through the encoder CUDA graph when enabled."""
        if not self.enable_cuda_graph:
            return self._encode(*inputs, **kwargs)
        if not self.encoder_cuda_graph_created:
            self._create_cuda_graph_encoder(*inputs, **kwargs)
        return self._graph_replay_encoder(*inputs, **kwargs)

    # ------------------------------------------------------------------ full forward

    def _graph_replay(self, *inputs, **kwargs):
        """Refresh the forward graph's static inputs, replay it, and return its static output."""
        self._copy_into_static_buffers(self.static_inputs, self.static_kwargs, inputs, kwargs)
        self._all_cuda_graph.replay()
        return self.static_output

    def forward(self, *inputs, **kwargs):
        """Run the wrapped VAE, going through the forward CUDA graph when enabled."""
        if not self.enable_cuda_graph:
            return self._forward(*inputs, **kwargs)
        # Check the flag that _create_cuda_graph actually sets. Testing
        # ``cuda_graph_created``, which this class never assigns, raised
        # AttributeError on the first graph-enabled call.
        if not self.all_cuda_graph_created:
            self._create_cuda_graph(*inputs, **kwargs)
        return self._graph_replay(*inputs, **kwargs)

    def _create_cuda_graph(self, *inputs, **kwargs):
        """Capture the forward graph. ``inputs``/``kwargs`` become its static buffers."""
        self.static_inputs = inputs
        self.static_kwargs = kwargs
        self._all_cuda_graph, self.static_output = self._capture_cuda_graph(self._forward, inputs, kwargs)
        self.all_cuda_graph_created = True

    def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
        """Eager forward: call ``self.vae`` with all four arguments positionally."""
        # NOTE(review): this (sample, timestamp, encoder_hidden_states) argument list looks
        # like it was copied from a UNet wrapper. Confirm it matches the wrapped VAE's forward().
        return self.vae(sample, timestamp, encoder_hidden_states, return_dict)