|
129 | 129 | return result;
130 | 130 | }
131 | 131 |
|
| 132 | +Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) { |
| 133 | + AT_DISPATCH_MPS_TYPES(result.scalar_type(), "range_mps", [&]() {
| 134 | + using accscalar_t = at::acc_type<scalar_t, true>; |
| 135 | + auto xstart = start.to<accscalar_t>(); |
| 136 | + auto xend = end.to<accscalar_t>(); |
| 137 | + auto xstep = step.to<accscalar_t>(); |
| 138 | + |
| 139 | + // range is endpoint-inclusive, so size = ((end - start) / step) + 1; the int64_t branch does the subtraction in integer math first to avoid double-rounding large values.
| 140 | + double size_d; |
| 141 | + if (std::is_same<scalar_t, int64_t>::value) { |
| 142 | + size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>()) |
| 143 | + / step.to<accscalar_t>() + 1; |
| 144 | + } else { |
| 145 | + size_d = static_cast<double>(end.to<double>() - start.to<double>()) |
| 146 | + / step.to<double>() + 1; |
| 147 | + } |
| 148 | + |
| 149 | + TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); // written this way so a NaN step also fails the check
| 150 | + TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && |
| 151 | + std::isfinite(static_cast<double>(xend)), |
| 152 | + "unsupported range: ", xstart, " -> ", xend); |
| 153 | + TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), |
| 154 | + "upper bound and larger bound inconsistent with step sign"); |
| 155 | + |
| 156 | + TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()), |
| 157 | + "invalid size, possible overflow?"); |
| 158 | + |
| 159 | + int64_t size = static_cast<int64_t>(size_d); |
| 160 | + |
| 161 | + int64_t numel = result.numel(); |
| 162 | + |
| 163 | + if (numel != size) { |
| 164 | + result.resize_({size}); |
| 165 | + } |
| 166 | + bool is_contiguous = result.is_contiguous(); |
| 167 | + Tensor r = is_contiguous ? result : at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // stage through a contiguous temp when the output is strided
| 168 | + using namespace mps; |
| 169 | + auto cache_ = MPSGraphCache::getInstance(); // per-process graph cache, keyed below by op, tensor signature, and size
| 170 | + auto stream = getCurrentMPSStream(); |
| 171 | + auto mpsDataType = getMPSDataType(result.scalar_type()); |
| 172 | + @autoreleasepool { |
| 173 | + string key = "range_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
| 174 | + auto cachedGraph = static_cast<RangeCachedGraph *>(cache_->LookUp(key)); |
| 175 | + if (!cachedGraph) { |
| 176 | + auto *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph *() { |
| 177 | + auto mpsGraph = make_mps_graph(); |
| 178 | + return new RangeCachedGraph(mpsGraph, mpsDataType, size); |
| 179 | + }); |
| 180 | + cachedGraph = static_cast<RangeCachedGraph *>(tmpCachedGraph); |
| 181 | + } |
| 182 | + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r); |
| 183 | + NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; |
| 184 | + MPSScalar startScalar = getMPSScalar(start, result.scalar_type()); |
| 185 | + feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar); |
| 186 | + MPSScalar stepScalar = getMPSScalar(step, result.scalar_type()); |
| 187 | + feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar); // step feeds the multiply node: output[i] = start + i * step
| 188 | + |
| 189 | + NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{ |
| 190 | + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() |
| 191 | + }; |
| 192 | + runMPSGraph(stream, cachedGraph->graph(), feeds, results); |
| 193 | + } |
| 194 | + |
| 195 | + if (!is_contiguous) {
| 196 | + result.copy_(r); |
| 197 | + } |
| 198 | + }); |
| 199 | + |
| 200 | + return result; |
| 201 | +} |
| 202 | + |
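For context on the size computation above: unlike arange, range includes the endpoint, which is where the "+ 1" comes from. A minimal standalone sketch of the same rule, assuming nothing beyond the standard library (compute_range_size is a hypothetical helper, not part of this patch):

    // Endpoint-inclusive size rule, mirroring the checks in range_mps_out.
    #include <cstdint>
    #include <limits>
    #include <stdexcept>

    inline std::int64_t compute_range_size(double start, double end, double step) {
      if (step == 0.0) throw std::invalid_argument("step must be nonzero");
      double size_d = (end - start) / step + 1;
      if (size_d < 0 || size_d > static_cast<double>(std::numeric_limits<std::int64_t>::max()))
        throw std::overflow_error("invalid size, possible overflow?");
      return static_cast<std::int64_t>(size_d);  // truncation matches the cast in the patch
    }

    // compute_range_size(0, 10, 2) == 6  ->  {0, 2, 4, 6, 8, 10}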
132 | 203 | Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) { |
133 | 204 | using namespace mps; |
134 | 205 |
|
|
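A quick smoke test through the public C++ API may help reviewers. This is a sketch assuming a libtorch build with MPS support; torch::range should route to the new kernel when the output tensor lives on an MPS device:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      if (!at::hasMPS()) return 0;  // skip on machines without an MPS device
      auto opts = torch::TensorOptions().device(torch::kMPS).dtype(torch::kFloat);
      auto t = torch::range(/*start=*/0, /*end=*/10, /*step=*/2, opts);
      std::cout << t.cpu() << "\n";  // expected: 0 2 4 6 8 10 (endpoint included)
      return 0;
    }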