forked from JarodMica/rvc-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path api.py
More file actions
134 lines (112 loc) · 4.56 KB
/
api.py
File metadata and controls
134 lines (112 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, JSONResponse
from loguru import logger
from pydantic import BaseModel
import tempfile
import base64
import shutil
import zipfile
import os
# Request body for POST /set_device.
class SetDeviceRequest(BaseModel):
    # Device identifier forwarded unchanged to app.state.rvc.set_device();
    # the accepted values depend on that implementation (e.g. "cpu" / "cuda:0"
    # — TODO confirm).
    device: str
# Request body for POST /convert.
class ConvertAudioRequest(BaseModel):
    # Base64-encoded audio bytes. The handler decodes this with
    # base64.b64decode and writes the result to a ".wav" temp file, so it
    # presumably must be a complete WAV file — TODO confirm which formats
    # rvc.infer_file actually accepts.
    audio_data: str
# Request body for POST /params.
class SetParamsRequest(BaseModel):
    # Arbitrary mapping unpacked as keyword arguments into
    # app.state.rvc.set_params(**params). No key validation happens here;
    # presumably keys mirror the fields returned by GET /params
    # (f0method, f0up_key, index_rate, ...) — verify against set_params.
    params: dict
# Request body for POST /set_models_dir.
class SetModelsDirRequest(BaseModel):
    # Directory path forwarded unchanged to app.state.rvc.set_models_dir();
    # no existence or permission checks are done in this module.
    models_dir: str
def _remove_file_quietly(path):
    """Best-effort deletion of a temporary file.

    ``None`` and already-missing files are ignored. Any other OS error is
    logged rather than raised, so cleanup inside a ``finally`` clause cannot
    mask the request's real result or error.
    """
    if path is None:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError as err:
        logger.warning("Could not remove temporary file {}: {}", path, err)


def setup_routes(app: FastAPI):
    """Register the RVC HTTP endpoints on ``app``.

    Every handler uses the converter stored at ``app.state.rvc``. This
    function does not create it, so it must be assigned before requests are
    served.

    All handlers are plain ``def`` functions. They do blocking file and model
    work, and FastAPI runs sync handlers in a worker thread instead of on
    the event loop.
    """

    @app.post("/convert")
    def rvc_convert(request: ConvertAudioRequest):
        """Convert base64-encoded input audio with the currently loaded model.

        Returns the converted file as an ``audio/wav`` response. Responds
        with 400 when no model is loaded, and with 500 for any failure while
        decoding, staging files, or running inference.
        """
        if not app.state.rvc.current_model:
            raise HTTPException(status_code=400, detail="No model loaded. Please load a model first.")

        input_path = None
        output_path = None
        try:
            logger.info("Received request to convert audio")
            audio_data = base64.b64decode(request.audio_data)

            # Leaving the with-block closes (and therefore flushes) the file,
            # so infer_file() sees the complete input when it opens the path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_input:
                input_path = tmp_input.name
                tmp_input.write(audio_data)

            # Only reserve a path for the output and release our handle;
            # the result is read back by path after inference.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_output:
                output_path = tmp_output.name

            app.state.rvc.infer_file(input_path, output_path)

            with open(output_path, "rb") as result_file:
                output_data = result_file.read()
            return Response(content=output_data, media_type="audio/wav")
        except Exception as e:
            logger.exception("Audio conversion failed: {}", e)
            raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
        finally:
            _remove_file_quietly(input_path)
            _remove_file_quietly(output_path)

    @app.get("/models")
    def list_models():
        """Return ``{"models": ...}`` as reported by ``rvc.list_models()``."""
        return JSONResponse(content={"models": app.state.rvc.list_models()})

    @app.post("/models/{model_name}")
    def load_model(model_name: str):
        """Load ``model_name`` into the converter; any failure becomes HTTP 400."""
        try:
            app.state.rvc.load_model(model_name)
        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        return JSONResponse(content={"message": f"Model {model_name} loaded successfully"})

    @app.get("/params")
    def get_params():
        """Return the converter's current inference parameters."""
        rvc = app.state.rvc
        return JSONResponse(content={
            "f0method": rvc.f0method,
            "f0up_key": rvc.f0up_key,
            "index_rate": rvc.index_rate,
            "filter_radius": rvc.filter_radius,
            "resample_sr": rvc.resample_sr,
            "rms_mix_rate": rvc.rms_mix_rate,
            "protect": rvc.protect,
        })

    @app.post("/params")
    def set_params(request: SetParamsRequest):
        """Pass ``request.params`` as keyword arguments to ``rvc.set_params``.

        Any exception raised by ``set_params`` is reported as HTTP 400.
        """
        try:
            app.state.rvc.set_params(**request.params)
        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        return JSONResponse(content={"message": "Parameters updated successfully"})

    @app.post("/upload_model")
    def upload_models(file: UploadFile = File(...)):
        """Extract an uploaded zip archive into ``rvc.models_dir``.

        After extraction, ``rvc.models`` is rebuilt via
        ``rvc._load_available_models()``. Any failure, including a corrupt
        archive, is reported as HTTP 500.
        """
        tmp_path = None
        try:
            # Copy the upload to a named temp file. Closing it at the end of
            # the with-block flushes it before zipfile reopens it by name.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
                tmp_path = tmp_file.name
                shutil.copyfileobj(file.file, tmp_file)

            # NOTE(review): zipfile.extractall strips absolute paths and ".."
            # components, but nothing here limits the extracted size or file
            # count. Only accept archives from trusted clients, or add limits.
            with zipfile.ZipFile(tmp_path, "r") as zip_ref:
                zip_ref.extractall(app.state.rvc.models_dir)

            # Update the list of models after upload
            app.state.rvc.models = app.state.rvc._load_available_models()
            return JSONResponse(content={"message": "Models uploaded and extracted successfully"})
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        finally:
            # Runs on success and failure alike, so a bad archive does not
            # leave its temp copy behind.
            _remove_file_quietly(tmp_path)

    @app.post("/set_device")
    def set_device(request: SetDeviceRequest):
        """Switch the converter to ``request.device``; failures become HTTP 400."""
        device = request.device
        try:
            app.state.rvc.set_device(device)
        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        return JSONResponse(content={"message": f"Device set to {device}"})

    @app.post("/set_models_dir")
    def set_models_dir(request: SetModelsDirRequest):
        """Point the converter at ``request.models_dir``; failures become HTTP 400."""
        new_models_dir = request.models_dir
        try:
            app.state.rvc.set_models_dir(new_models_dir)
        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        return JSONResponse(content={"message": f"Models directory set to {new_models_dir}"})
def create_app(allowed_origins=None):
    """Create the FastAPI application with CORS enabled and all routes registered.

    Args:
        allowed_origins: Optional iterable of origin strings for the CORS
            middleware. ``None`` (the default) keeps the wildcard ``["*"]``,
            so existing callers get the same behaviour as before.

    Returns:
        The configured ``FastAPI`` instance. ``app.state.rvc`` is not set
        here; the handlers registered by ``setup_routes`` read it, so the
        caller must assign it before serving requests.
    """
    app = FastAPI()

    # NOTE(review): allow_credentials=True combined with a "*" origin is
    # risky. Starlette's CORSMiddleware then echoes the caller's Origin for
    # credentialed requests instead of sending "*" (confirm against the
    # installed Starlette version). Pass explicit allowed_origins when the
    # API is reachable from untrusted networks.
    origins = ["*"] if allowed_origins is None else list(allowed_origins)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    setup_routes(app)
    return app