-
Notifications
You must be signed in to change notification settings - Fork 37
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·1928 lines (1695 loc) · 77.9 KB
/
Copy pathtrain.py
File metadata and controls
executable file
·1928 lines (1695 loc) · 77.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import re
import sys
import glob
import json
import gc
import torch
import logging
import argparse
import datetime
import warnings
import torch.distributed as dist
import torch.multiprocessing as mp
# ── FIX: Ensure project root is in sys.path BEFORE any arvc imports ──
# When run as a subprocess (e.g. python arvc/engine/training/runner/train.py),
# the project root isn't automatically in the path, causing ModuleNotFoundError.
# Climb five levels from this file:
# train.py -> runner -> training -> engine -> arvc -> <project root>.
_project_root = os.path.abspath(__file__)
_project_root = os.path.dirname(os.path.dirname(_project_root))  # .../arvc/engine/training
_project_root = os.path.dirname(os.path.dirname(os.path.dirname(_project_root)))  # <project root>
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
# USE_LIBUV is "0" on Windows and "1" everywhere else
# (presumably consumed by torch.distributed store setup — confirm).
os.environ["USE_LIBUV"] = "1" if sys.platform != "win32" else "0"
from tqdm import tqdm
from collections import deque
from contextlib import nullcontext
from random import randint, shuffle
from arvc.utils import strtobool
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from time import time as ttime
from torch.nn.parallel import DistributedDataParallel as DDP
from arvc.engine.models.utils import clear_gpu_cache
from arvc.engine.models.backends import directml, opencl, zluda
# XPU backend — may not exist in all installations; graceful fallback
try:
    from arvc.engine.models.backends import xpu
except ImportError:
    # Backend module is not shipped with this installation: mark it absent.
    xpu = None
# ZLUDA detection: True when running on AMD GPU via CUDA compatibility layer.
# Evaluated once at import time; later code uses it to select the gloo DDP
# backend, to skip cuDNN/TF32 tuning, and to disable pinned memory.
_is_zluda = zluda.is_available()
from arvc.utils.variables import logger, translations as _raw_translations
# ── BULLETPROOF SAFETY NET ──
# Translation lookups must never crash training with a KeyError, whatever the
# language files contain, so the translation table is wrapped in this dict
# subclass.
class _SafeTranslations(dict):
    """Translation table in which every key counts as present.

    ``table[name]`` returns the stored text when ``name`` exists and ``name``
    itself otherwise, and ``name in table`` is always ``True``. ``dict.get``
    does not consult ``__missing__``, so ``table.get(name)`` still returns
    ``None`` for an absent key.
    """

    def __missing__(self, absent_name):
        # Echo the lookup key back as its own display text.
        return absent_name

    def __contains__(self, _name):
        # Report every key as present so `in` guards always take the lookup path.
        return True
# Module-wide translation table: a lookup of a missing key returns the key
# name itself instead of raising KeyError.
translations = _SafeTranslations(_raw_translations)
from arvc.engine.models.algorithms import commons
from arvc.engine.training.runner import losses
from arvc.engine.training.runner.extract_model import extract_model
from arvc.engine.training.runner.mel_processing import (
MultiScaleMelSpectrogramLoss,
mel_spectrogram_torch,
spec_to_mel_torch
)
from arvc.engine.training.runner.utils import (
HParams,
summarize,
load_checkpoint,
save_checkpoint,
load_wav_to_torch,
latest_checkpoint_path,
plot_spectrogram_to_numpy,
)
from arvc.engine.models.weight_norm import configure_weight_norm, use_new_pytorch
from arvc.utils.variables import config as main_config
from arvc.utils.variables import configs as main_configs
from arvc.utils.huggingface import HF_download_file
# Unless debug mode is on, silence Python warnings and hide torch log records
# below ERROR.
if getattr(main_config, 'debug_mode', False):
    pass  # debug mode: leave warnings and torch logging untouched
else:
    warnings.filterwarnings("ignore")
    logging.getLogger("torch").setLevel(logging.ERROR)
def parse_arguments():
    """Parse the training runner's command-line options from ``sys.argv``.

    Boolean options take a textual value that is converted through
    ``strtobool``; ``--train`` is a plain on/off switch.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    # Shared string -> bool converter. Kept as a lambda on purpose: argparse
    # names the type function in its error text ("invalid <lambda> value").
    as_bool = lambda text: bool(strtobool(text))  # noqa: E731

    cli = argparse.ArgumentParser()
    add = cli.add_argument

    add("--train", action='store_true')
    add("--model_name", type=str, required=True)
    add("--rvc_version", type=str, default="v2")
    add("--save_every_epoch", type=int, required=True)
    add("--save_only_latest", type=as_bool, default=True)
    add("--save_every_weights", type=as_bool, default=True)
    add("--total_epoch", type=int, default=300)
    add("--batch_size", type=int, default=8)
    add("--gpu", type=str, default="0")
    add("--pitch_guidance", type=as_bool, default=True)
    add("--g_pretrained_path", type=str, default="")
    add("--d_pretrained_path", type=str, default="")
    add("--overtraining_detector", type=as_bool, default=False)
    add("--overtraining_threshold", type=int, default=50)
    add("--cleanup", type=as_bool, default=False)
    add("--cache_data_in_gpu", type=as_bool, default=False)
    add("--model_author", type=str)
    add("--vocoder", type=str, default="Default")
    add("--checkpointing", type=as_bool, default=False)
    add("--deterministic", type=as_bool, default=False)
    add("--benchmark", type=as_bool, default=True)
    add("--optimizer", type=str, default="AdamW")
    add("--energy_use", type=as_bool, default=False)
    add("--use_custom_reference", type=as_bool, default=False)
    add("--reference_path", type=str, default="")
    add("--multiscale_mel_loss", type=as_bool, default=True)
    add("--use_cosine_annealing_lr", type=as_bool, default=False)
    add("--architecture", type=str, default="RVC", help="Model architecture: RVC or SVC")
    add("--compile_model", type=as_bool, default=False, help="Use torch.compile() on generator for PyTorch 2.x speedup")
    add("--use_8bit_adam", type=as_bool, default=False, help="Use 8-bit Adam optimizer for lower VRAM (requires bitsandbytes)")
    add("--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps (reduces VRAM usage with larger effective batch sizes)")
    add("--newpytorch", type=as_bool, default=True, help="Use PyTorch 2.0+ parametrization format (default, matches Applio/VRVC). Set false for legacy weight_norm format.")

    fast_train_help = (
        "Vocal-quality-safe training speedup bundle. Enables TF32 matmul+cuDNN (Ampere+), "
        "larger dataloader prefetch, higher worker count, auto torch.compile, and reduces tqdm "
        "update overhead. Targets ~3x faster training with NO loss in vocal fidelity — only I/O "
        "and kernel-fusion optimizations are applied, never numerical changes."
    )
    add("--fast_train", type=as_bool, default=False, help=fast_train_help)

    bf16_help = (
        "Applio-parity shortcut: use AnyPrecisionAdamW with bf16 training precision. "
        "Equivalent to passing --optimizer=AnyPrecisionAdamW and setting brain=True "
        "in assets/config.json. Recommended on Ampere+ GPUs (RTX 30xx/40xx/A100/H100)."
    )
    add("--bf16_adamw", type=as_bool, default=False, help=bf16_help)

    return cli.parse_args()
# Learning-rate multipliers and generator/discriminator step ratio.
d_lr_coeff = 1.0
g_lr_coeff = 1.0
d_step_per_g_step = 1
randomized = True # Applio-style: random slice for training, full-sequence for finetuning

# Parse the CLI once and expose every option as a module-level global.
args = parse_arguments()
model_name = args.model_name
save_every_epoch = args.save_every_epoch
total_epoch = args.total_epoch
pretrainG = args.g_pretrained_path
pretrainD = args.d_pretrained_path
version = args.rvc_version
gpus = args.gpu
batch_size = args.batch_size
pitch_guidance = args.pitch_guidance
save_only_latest = args.save_only_latest
save_every_weights = args.save_every_weights
cache_data_in_gpu = args.cache_data_in_gpu
overtraining_detector = args.overtraining_detector
overtraining_threshold = args.overtraining_threshold
cleanup = args.cleanup
model_author = args.model_author
vocoder = args.vocoder
checkpointing = args.checkpointing
optimizer_choice = args.optimizer
energy_use = args.energy_use
use_custom_reference = args.use_custom_reference
reference_path = args.reference_path
multiscale_mel_loss = args.multiscale_mel_loss
use_cosine_annealing_lr = args.use_cosine_annealing_lr
architecture = args.architecture
compile_model = args.compile_model
use_8bit_adam = args.use_8bit_adam
grad_accum_steps = args.grad_accum_steps
newpytorch = args.newpytorch
fast_train = args.fast_train
bf16_adamw = args.bf16_adamw
# ── FAST-TRAIN BUNDLE (opt-in speedups for CUDA, non-ZLUDA) ──────────────────
# These knobs leave the loss functions, gradient path and model definition
# alone, but they are NOT all numerically neutral. TF32 lowers float32
# matmul/convolution precision (10-bit mantissa instead of FP32's 23-bit), so
# results are not bit-identical to a plain FP32 run. The other knobs only
# affect kernel selection, the CUDA allocator and torch.compile.
if fast_train and torch.cuda.is_available() and not _is_zluda:
    # 1. TF32 matmul/convolutions on Ampere+ GPUs — the largest speed lever here.
    try:
        torch.set_float32_matmul_precision("high")  # 'high' == TF32
    except Exception:
        pass  # best effort: the allow_tf32 flags below still apply
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # 2. cuDNN benchmark — picks the fastest conv kernel for the current input
    #    shape. Already on with the default --benchmark=True; forced on here.
    torch.backends.cudnn.benchmark = True
    # 3. Keep cuDNN's non-deterministic kernel picker unless --deterministic
    #    was passed. `deterministic` is not one of the unpacked CLI globals,
    #    so the flag must be read from `args` (a bare name raises NameError).
    if not args.deterministic:
        torch.backends.cudnn.deterministic = False
    # 4. CUDA allocator config — 'expandable_segments' reduces fragmentation on
    #    long runs. setdefault keeps any value the user already exported.
    os.environ.setdefault(
        "PYTORCH_CUDA_ALLOC_CONF",
        "expandable_segments:True,"
        "max_split_size_mb:512",
    )
    # 5. Auto-enable torch.compile when available. --compile_model defaults to
    #    False, so an explicit "false" cannot be told apart from the default
    #    and is overridden as well.
    if not compile_model and hasattr(torch, "compile"):
        compile_model = True
    if __name__ == "__main__":
        print("[Advanced-RVC] FAST-TRAIN: TF32 + cuDNN benchmark + torch.compile + expandable_segments enabled (vocal-quality-safe).")

# bf16 path (--bf16_adamw). The flag is documented as equivalent to
# --optimizer=AnyPrecisionAdamW plus brain=True, so it must not depend on
# --fast_train. run() always forces AnyPrecisionAdamW when it is set.
# Flipping the global flag here, before `is_half` is derived from it, makes
# the precision selection and the fallback optimizer path see brain=True.
if bf16_adamw and torch.cuda.is_available() and not _is_zluda and not getattr(main_config, 'brain', False):
    try:
        main_config.brain = True
        if __name__ == "__main__":
            print("[Advanced-RVC] bf16_adamw: auto-enabled brain=True for AnyPrecisionAdamW + bf16 autocast.")
    except Exception:
        pass  # best effort: a read-only config keeps its original precision
# ── Configure weight_norm mode BEFORE any model creation ──
configure_weight_norm(newpytorch)
if __name__ == "__main__":
    # Report which weight format the trained model will use (script runs only).
    print(
        "[Advanced-RVC] PyTorch weight format: "
        + ("NEW (2.0+ parametrization)" if newpytorch else "OLD (weight_norm, RVC fork compatible)")
    )
# Discriminator version follows the RVC version, except that the BigVGAN and
# RefineGAN vocoders always use the v3 discriminator (matches VRVC).
if vocoder in ["RefineGAN", "BigVGAN"]:
    disc_version = "v3"
else:
    disc_version = version
# Half precision comes from the config; bf16 ("brain") mode forces it on.
# Logic matches Vietnamese-RVC.
is_half = main_config.is_half
if getattr(main_config, 'brain', False):
    is_half = True
# SVC architecture overrides (from Vietnamese-RVC): the discriminator becomes
# v0 with the default vocoder, and pitch is always used without energy.
if architecture == "SVC":
    disc_version = "v0" if vocoder == "Default" else version
    pitch_guidance = True
    energy_use = False
# Vietnamese-RVC style experiment_dir / checkpoint_path handling.
# `model_name` is either a plain name, which lives under <logs_path>/<name>,
# or a path to an existing experiment directory. In the latter case the
# model name is that directory's last component and checkpoints go to
# weights_path.
weights_path = main_configs["weights_path"]
logs_path = main_configs["logs_path"]
custom_save_checkpoint_path = None
if not os.path.exists(model_name):
    experiment_dir = os.path.join(logs_path, model_name)
else:
    experiment_dir = model_name
    # normpath drops a trailing separator first: basename("runs/my_voice/")
    # would otherwise be "" and leave the model without a name.
    model_name = os.path.basename(os.path.normpath(model_name))
    custom_save_checkpoint_path = weights_path
checkpoint_path = experiment_dir if custom_save_checkpoint_path is None else custom_save_checkpoint_path
training_file_path = os.path.join(experiment_dir, "training_data.json")
config_save_path = os.path.join(experiment_dir, "config.json")
filelist_path = os.path.join(experiment_dir, "filelist.txt")
eval_dir = os.path.join(experiment_dir, "eval")
spec_dirs = None
save_the_pid = True
cache_spectrogram = True
use_clip_grad_value = False
# Create config.json from the bundled per-version template if it doesn't exist.
# The template is chosen by the sample rate of the extracted audio, when any
# can be inspected; otherwise 32 kHz is assumed.
if not os.path.exists(config_save_path):
    import shutil
    os.makedirs(experiment_dir, exist_ok=True)
    sr = 32000 # default sample rate
    extracted_dir = os.path.join(experiment_dir, f"{version}_extracted")
    if os.path.exists(extracted_dir):
        wav_files = glob.glob(os.path.join(extracted_dir, "*.wav"))
        if wav_files:
            try:
                import soundfile as sf
                # Only the header is needed; sf.read would decode the whole file.
                sr = sf.info(wav_files[0]).samplerate
            except Exception:
                # soundfile missing or file unreadable: keep the default rate.
                pass
    _requested_template = os.path.join(main_configs["configs_path"], version, f"{sr}.json")
    config_template_path = _requested_template
    if not os.path.exists(config_template_path):
        # Fall back to the first template that exists, in this fixed order.
        for fallback_sr in [40000, 32000, 48000, 24000, 44100]:
            config_template_path = os.path.join(main_configs["configs_path"], version, f"{fallback_sr}.json")
            if os.path.exists(config_template_path):
                break
    if os.path.exists(config_template_path):
        shutil.copy(config_template_path, config_save_path)
    else:
        # Report the template that was actually wanted, not the last fallback tried.
        raise FileNotFoundError(
            f"Config template not found at: {_requested_template} "
            f"(no fallback sample-rate template found either)"
        )
# cuDNN / TF32 — controlled by CLI flags and config, not forced ON.
# OpenCL/DirectML devices and ZLUDA always get cuDNN determinism and
# benchmarking switched off. This block runs after the fast-train setup. When
# --fast_train is active on CUDA, it keeps the TF32 and cuDNN-benchmark
# settings that mode requested; otherwise they would be reset to the
# config/CLI defaults (tf32 defaults to off).
_cudnn_tunable = not main_config.device.startswith(("ocl", "privateuseone")) and not _is_zluda
_fast_train_cuda = fast_train and torch.cuda.is_available() and not _is_zluda
torch.backends.cudnn.deterministic = args.deterministic if _cudnn_tunable else False
torch.backends.cudnn.benchmark = (args.benchmark or _fast_train_cuda) if _cudnn_tunable else False
tf32_enabled = getattr(main_config, 'tf32', False) or _fast_train_cuda
if torch.cuda.is_available() and not _is_zluda:
    torch.backends.cuda.matmul.allow_tf32 = tf32_enabled
    torch.backends.cudnn.allow_tf32 = tf32_enabled
# Training-progress state kept at module level.
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
global_step = 0
last_loss_gen_all = 0
overtrain_save_epoch = 0
loss_gen_history = []
smoothed_loss_gen_history = []
loss_disc_history = []
smoothed_loss_disc_history = []
consecutive_increases_gen = 0
consecutive_increases_disc = 0
# Sliding windows holding the 50 most recent values of each tracked metric.
avg_losses = {
    metric: deque(maxlen=50)
    for metric in (
        "grad_d_50",
        "grad_g_50",
        "disc_loss_50",
        "adv_loss_50",
        "fm_loss_50",
        "kl_loss_50",
        "mel_loss_50",
        "gen_loss_50",
        "energy_loss_50",
    )
}
# Load the experiment's hyper-parameters and point training at its filelist.
with open(config_save_path, "r", encoding="utf-8") as f:
    config = HParams(**json.load(f))
config.data.training_files = filelist_path
def main():
    """Validate the dataset, pick the training device, and launch training.

    Logs the run configuration and checks that the sliced dataset exists and
    matches the configured sample rate; exits with status 1 otherwise. It then
    selects the device and GPU list, optionally cleans old checkpoints and
    indexes (``--cleanup``), and restores saved overtraining-detector history.
    Finally it spawns one ``run`` process per device and waits for all of them.
    Any other exception is logged rather than propagated.
    """
    global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author, vocoder, checkpointing, gpus, energy_use
    log_data = {
        translations["modelname"]: model_name,
        translations["save_every_epoch"]: save_every_epoch,
        translations["total_e"]: total_epoch,
        translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD): "",
        translations["training_version"]: version,
        "Gpu": gpus,
        translations["batch_size"]: batch_size,
        translations["training_f0"]: pitch_guidance,
        translations["save_only_latest"]: save_only_latest,
        translations["save_every_weights"]: save_every_weights,
        translations["cache_in_gpu"]: cache_data_in_gpu,
        translations["overtraining_detector"]: overtraining_detector,
        translations["threshold"]: overtraining_threshold,
        translations["cleanup_training"]: cleanup,
        translations["memory_efficient_training"]: checkpointing,
        translations["optimizer"]: optimizer_choice,
        translations["train&energy"]: energy_use,
        translations["multiscale_mel_loss"]: multiscale_mel_loss,
        translations["cosine_annealing_lr"]: use_cosine_annealing_lr,
        translations["architecture"]: architecture,
    }
    if model_author:
        log_data[translations["model_author"].format(model_author=model_author)] = ""
    if vocoder != "Default":
        log_data[translations['vocoder']] = vocoder
    for key, value in log_data.items():
        logger.debug(f"{key}: {value}" if value != "" else f"{key} {value}")
    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(randint(20000, 55555))
        wavs = glob.glob(os.path.join(os.path.join(experiment_dir, "sliced_audios"), "*.wav"))
        if wavs:
            _, sr = load_wav_to_torch(wavs[0])
            if sr != config.data.sample_rate:
                logger.warning(translations["training_sr"].format(sr_1=config.data.sample_rate, sr_2=sr))
                sys.exit(1)
        else:
            logger.warning(translations["not_found_dataset"])
            sys.exit(1)
        # Device selection — Vietnamese-RVC style with XPU + CPU fallback + Advanced-RVC ZLUDA
        if gpus == "-":
            device, gpus = torch.device("cpu"), [0]
            n_gpus = 1
            logger.warning(translations["not_gpu"])
        elif torch.cuda.is_available() and main_config.device.startswith("cuda"):
            if _is_zluda:
                device = torch.device("cuda")
                gpus = [0]
                n_gpus = 1
                logger.info("ZLUDA detected (AMD GPU) — using single GPU mode with gloo backend")
            else:
                device, gpus = torch.device("cuda"), [int(item) for item in gpus.split("-")]
                n_gpus = len(gpus)
        elif hasattr(torch, "xpu") and torch.xpu.is_available() and main_config.device.startswith("xpu"):
            device, gpus = torch.device("xpu"), [int(item) for item in gpus.split("-")]
            n_gpus = len(gpus)
        elif opencl.is_available() and main_config.device.startswith("ocl"):
            device, gpus = torch.device("ocl"), [int(item) for item in gpus.split("-")]
            n_gpus = len(gpus)
        elif directml.is_available() and main_config.device.startswith("privateuseone"):
            device, gpus = torch.device("privateuseone"), [int(item) for item in gpus.split("-")]
            n_gpus = len(gpus)
        elif torch.backends.mps.is_available() and main_config.device.startswith("mps"):
            device, gpus = torch.device("mps"), [0]
            n_gpus = 1
        else:
            device, gpus = torch.device("cpu"), [0]
            n_gpus = 1
            logger.warning(translations["not_gpu"])
        logger.info(
            translations["use_precision"].format(
                fp=("BF16" if getattr(main_config, 'brain', False) else "FP16") if is_half else "FP32"
            )
        )

        def start():
            """Spawn one `run` process per device, record their PIDs in
            config.json (when save_the_pid is set), and wait for all of them."""
            children = []
            pid_data = {"process_pids": []}
            if save_the_pid:
                with open(config_save_path, "r", encoding="utf-8") as f:
                    try:
                        saved = json.load(f)
                    except json.JSONDecodeError:
                        saved = {}
                if isinstance(saved, dict):
                    # Drop PIDs recorded by an earlier run. Only this run's
                    # processes belong here; otherwise stale (possibly reused)
                    # PIDs would accumulate in config.json.
                    saved.pop("process_pids", None)
                    pid_data.update(saved)
            for rank, device_id in enumerate(gpus):
                subproc = mp.Process(
                    target=run,
                    args=(
                        rank,
                        n_gpus,
                        pretrainG,
                        pretrainD,
                        pitch_guidance,
                        total_epoch,
                        save_every_weights,
                        config,
                        device,
                        device_id,
                        model_author,
                        vocoder,
                        checkpointing,
                        energy_use,
                        compile_model,
                        fast_train,
                    )
                )
                children.append(subproc)
                subproc.start()
                pid_data["process_pids"].append(subproc.pid)
            if save_the_pid:
                with open(config_save_path, "w", encoding="utf-8") as f:
                    json.dump(pid_data, f, indent=4)
            for child in children:
                child.join()

        def load_from_json(file_path):
            """Return the four saved loss-history lists (disc, smoothed disc,
            gen, smoothed gen) from *file_path*, or empty lists if it is absent."""
            if os.path.exists(file_path):
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                return (
                    data.get("loss_disc_history", []),
                    data.get("smoothed_loss_disc_history", []),
                    data.get("loss_gen_history", []),
                    data.get("smoothed_loss_gen_history", [])
                )
            return [], [], [], []

        def continue_overtrain_detector(training_file_path):
            """Restore the overtraining-detector loss histories from a previous
            run into the module globals, if the detector is enabled."""
            # Without this declaration the unpacking below would only bind
            # locals of this helper, and the restored history would be lost.
            global loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history
            if overtraining_detector and os.path.exists(training_file_path):
                (
                    loss_disc_history,
                    smoothed_loss_disc_history,
                    loss_gen_history,
                    smoothed_loss_gen_history
                ) = load_from_json(training_file_path)

        if cleanup:
            # Bottom-up walk: files inside "eval" are visited before the
            # folder itself is removed at its parent level.
            for root, dirs, files in os.walk(experiment_dir, topdown=False):
                for name in files:
                    file_path = os.path.join(root, name)
                    file_name, file_extension = os.path.splitext(name)
                    if (
                        file_extension == ".0" or
                        (file_name.startswith(("D_", "G_")) and file_extension == ".pth") or
                        (file_name.startswith(("added", "trained")) and file_extension == ".index")
                    ):
                        os.remove(file_path)
                for name in dirs:
                    if name == "eval":
                        folder_path = os.path.join(root, name)
                        for item in os.listdir(folder_path):
                            item_path = os.path.join(folder_path, item)
                            if os.path.isfile(item_path):
                                os.remove(item_path)
                        os.rmdir(folder_path)
        continue_overtrain_detector(training_file_path)
        start()
    except Exception as e:
        logger.error(f"{translations['training_error']} {e}")
        import traceback
        logger.debug(traceback.format_exc())
class EpochRecorder:
    """Measure the wall-clock time between successive ``record()`` calls."""

    def __init__(self):
        # Start of the first measured interval.
        self.last_time = ttime()

    def record(self):
        """Return a translated status line with the current clock time and the
        time elapsed since the previous call (or since construction)."""
        previous, self.last_time = self.last_time, ttime()
        elapsed_seconds = int(round(self.last_time - previous, 1))
        clock = datetime.datetime.now().strftime("%H:%M:%S")
        return translations["time_or_speed_training"].format(
            current_time=clock,
            elapsed_time_str=str(datetime.timedelta(seconds=elapsed_seconds)),
        )
def run(
rank,
n_gpus,
pretrainG,
pretrainD,
pitch_guidance,
custom_total_epoch,
custom_save_every_weights,
config,
device,
device_id,
model_author,
vocoder,
checkpointing,
energy_use,
compile_model,
fast_train=False,
):
global global_step, smoothed_value_gen, smoothed_value_disc, optimizer_choice
smoothed_value_gen, smoothed_value_disc = 0, 0
# DDP backend selection — Vietnamese-RVC style with XPU + Advanced-RVC ZLUDA
_ddp_backend = "gloo" if (sys.platform == "win32" or device.type not in ["cuda", "xpu"] or _is_zluda) else ("xccl" if device.type == "xpu" else "nccl")
dist.init_process_group(
backend=_ddp_backend,
init_method="env://",
world_size=n_gpus if device.type in ["cuda", "xpu"] else 1,
rank=rank if device.type in ["cuda", "xpu"] else 0
)
torch.manual_seed(config.train.seed)
if device.type == "cuda": torch.cuda.manual_seed(config.train.seed)
elif device.type == "xpu": torch.xpu.manual_seed(config.train.seed)
elif device.type == "ocl": opencl.pytorch_ocl.manual_seed_all(config.train.seed)
if torch.cuda.is_available(): torch.cuda.set_device(device_id)
elif hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.set_device(device_id)
if rank == 0:
if _is_zluda:
logger.info(f"Training on ZLUDA (AMD GPU): {zluda.device_name(0)}")
writer_eval = SummaryWriter(
log_dir=eval_dir
) if rank == 0 else None
from arvc.engine.training.runner.data_utils import (
DistributedBucketSampler,
TextAudioCollate,
TextAudioLoader
)
train_dataset = TextAudioLoader(
config.data,
spec_dirs=spec_dirs,
cache_spectrogram=cache_spectrogram,
pitch_guidance=pitch_guidance,
energy=energy_use
)
# Adaptive data loader settings — bumped under --fast_train for better
# I/O pipelining. These are vocal-quality-safe: they only affect how
# many batches are prefetched in parallel, never the math.
_pin_mem = not _is_zluda
if fast_train and not _is_zluda:
# FAST-TRAIN: more workers + larger prefetch factor → better overlap
# of CPU data loading with GPU compute. Capped to avoid CPU
# oversubscription on small machines.
import multiprocessing as _mp
_cpu = _mp.cpu_count() or 4
_num_workers = min(8, max(4, _cpu // 2))
_prefetch = 16
else:
_num_workers = 2 if _is_zluda else 4
_prefetch = 2 if _is_zluda else 8
if rank == 0 and fast_train:
logger.info(
f"FAST-TRAIN dataloader: num_workers={_num_workers}, prefetch_factor={_prefetch}, "
f"pin_memory={_pin_mem}, persistent_workers=True"
)
train_loader = DataLoader(
train_dataset,
num_workers=_num_workers,
shuffle=False,
pin_memory=_pin_mem,
batch_size=1 if architecture != "SVC" else batch_size,
collate_fn=TextAudioCollate(
pitch_guidance=pitch_guidance,
energy=energy_use
),
batch_sampler=DistributedBucketSampler(
train_dataset,
batch_size,
[50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
num_replicas=n_gpus,
rank=rank,
shuffle=True
) if architecture != "SVC" else None,
persistent_workers=True,
prefetch_factor=_prefetch
)
if len(train_loader) < 3:
logger.warning(translations["not_enough_data"])
sys.exit(1)
# ── Dynamic spk_dim detection from checkpoint (Vietnamese-RVC feature) ──
spk_dim = config.model.spk_embed_dim
try:
spk_dim = config.sid
except Exception as e:
logger.debug(e)
try:
g_path = os.path.join(checkpoint_path, "G_latest.pth")
last_g = g_path if save_only_latest and os.path.exists(g_path) else latest_checkpoint_path(checkpoint_path, "G_*.pth")
chk_path = (last_g if last_g else (pretrainG if pretrainG not in ["", "None"] else None))
if chk_path:
from arvc.engine.models.safe_load import safe_torch_load
ckpt = safe_torch_load(chk_path)
spk_dim = ckpt["model"]["emb_g.weight"].shape[0]
del ckpt
except Exception as e:
logger.debug(e)
config.model.spk_embed_dim = spk_dim
from arvc.engine.models.algorithms.synthesizers import Synthesizer
from arvc.engine.models.algorithms.discriminators import MultiPeriodDiscriminator
# SVC architecture support (Vietnamese-RVC feature)
_has_svc = False
try:
from arvc.engine.models.algorithms.synthesizers import SynthesizerSVC
_has_svc = True
except ImportError:
pass
if architecture == "SVC" and _has_svc:
net_g, net_d = (
SynthesizerSVC(
config.data.filter_length // 2 + 1,
config.train.segment_size // config.data.hop_length,
**config.model,
sr=config.data.sample_rate,
vocoder=vocoder,
checkpointing=checkpointing,
),
MultiPeriodDiscriminator(
version=disc_version,
use_spectral_norm=config.model.use_spectral_norm,
checkpointing=checkpointing
)
)
else:
net_g, net_d = (
Synthesizer(
config.data.filter_length // 2 + 1,
config.train.segment_size // config.data.hop_length,
**config.model,
use_f0=pitch_guidance,
sr=config.data.sample_rate,
vocoder=vocoder,
randomized=randomized,
checkpointing=checkpointing,
energy=energy_use
),
MultiPeriodDiscriminator(
version=disc_version,
use_spectral_norm=config.model.use_spectral_norm,
checkpointing=checkpointing
)
)
# Move to device — Vietnamese-RVC style with XPU support
net_g, net_d = (
net_g.cuda(device_id),
net_d.cuda(device_id)
) if torch.cuda.is_available() else (
net_g.xpu(device_id),
net_d.xpu(device_id)
) if hasattr(torch, "xpu") and torch.xpu.is_available() else (
net_g.to(device),
net_d.to(device)
)
# ── Optimizer selection ──
# Use the Advanced-RVC optimizer registry when available, with Vietnamese-RVC
# fallbacks for AdaBeliefV2 / InverseSqrt scheduler
_use_registry = True
# SPEED PATCH (Applio parity): if --bf16_adamw was passed, force the
# optimizer to AnyPrecisionAdamW. The fast_train bundle above already
# set main_config.brain=True, so the autocast dtype will be bf16 and
# AnyPrecisionAdamW will keep fp32 master weights + bf16 momentum.
if bf16_adamw:
optimizer_choice = "AnyPrecisionAdamW"
if rank == 0:
logger.info(
"Applio-parity bf16_adamw: forcing optimizer=AnyPrecisionAdamW "
"with bf16 autocast (brain=True)."
)
try:
from arvc.engine.models.optimizers import get_optimizer_class, get_optimizer_info
except ImportError:
_use_registry = False
# Vietnamese-RVC style InverseSqrt scheduler import for AdaBeliefV2
get_inverse_sqrt_scheduler = None
try:
from arvc.engine.models.optimizers.adabeliefv2 import AdaBeliefV2 as _AdaBeliefV2, get_inverse_sqrt_scheduler as _get_inv_sqrt
get_inverse_sqrt_scheduler = _get_inv_sqrt
except ImportError:
pass
if _use_registry:
    # Registry path: resolve the optimizer class and its capability metadata,
    # falling back to AdamW when the requested name is unknown.
    try:
        optimizer_optim = get_optimizer_class(optimizer_choice)
        optimizer_meta = get_optimizer_info(optimizer_choice)
    except ValueError:
        logger.warning(f"Unknown optimizer '{optimizer_choice}', falling back to AdamW")
        optimizer_choice = "AdamW"
        optimizer_optim = get_optimizer_class("AdamW")
        optimizer_meta = get_optimizer_info("AdamW")
    if rank == 0:
        logger.info(f"Optimizer: {optimizer_choice} (Rating: {optimizer_meta.get('rating', 'N/A')}/5 - {optimizer_meta.get('category', 'N/A')})")

    def _accepts_fused(opt_cls):
        """Return True if ``opt_cls`` can be constructed with a ``fused`` keyword.

        torch.optim optimizers such as AdamW take ``fused`` as a constructor
        argument rather than exposing it as a class attribute, so a bare
        ``hasattr(opt_cls, "fused")`` check never enables fusing for them.
        The attribute check is kept for classes that do expose it.
        """
        if hasattr(opt_cls, "fused"):
            return True
        try:
            import inspect
            return "fused" in inspect.signature(opt_cls).parameters
        except (TypeError, ValueError):
            # Constructor signature is not introspectable: do not request fusing.
            return False

    # CUDA Optimizer Training: use fused kernels only on a native CUDA device
    # (not ZLUDA), when the registry marks the optimizer as fused-capable and
    # its constructor actually accepts `fused`.
    use_fused_optimizer = (
        device.type == "cuda"
        and not _is_zluda
        and optimizer_meta.get("supports_fused", False)
        and _accepts_fused(optimizer_optim)
    )

    def _build_optimizer_kwargs(lr_coeff):
        """Build constructor kwargs for ``optimizer_optim``.

        Only options the registry metadata reports as supported are included.
        Weight decay is forced to 0.0 when the optimizer supports it.

        Args:
            lr_coeff: Multiplier applied to ``config.train.learning_rate``.

        Returns:
            dict of keyword arguments for the optimizer constructor.
        """
        kwargs = {"lr": config.train.learning_rate * lr_coeff}
        if optimizer_meta.get("supports_betas"):
            kwargs["betas"] = config.train.betas
        if optimizer_meta.get("supports_eps"):
            kwargs["eps"] = config.train.eps
        if optimizer_meta.get("supports_weight_decay"):
            kwargs["weight_decay"] = 0.0
        if use_fused_optimizer:
            kwargs["fused"] = True
        return kwargs

    # 8-bit Adam (requires bitsandbytes) — Advanced-RVC feature
    if use_8bit_adam and device.type == "cuda":
        try:
            import bitsandbytes as bnb
            if rank == 0:
                logger.info(f"Using 8-bit AdamW via bitsandbytes for reduced VRAM usage (requested optimizer: {optimizer_choice})")
            # AdamW8bit is used whatever optimizer_choice says, hence the log
            # message above. eps is passed explicitly so it matches
            # config.train.eps like the other construction paths.
            # NOTE(review): weight_decay is left at the bitsandbytes default,
            # whereas _build_optimizer_kwargs forces 0.0 for optimizers that
            # support it — confirm which is intended.
            optim_g = bnb.optim.AdamW8bit(net_g.parameters(), lr=config.train.learning_rate * g_lr_coeff, betas=config.train.betas, eps=config.train.eps)
            optim_d = bnb.optim.AdamW8bit(net_d.parameters(), lr=config.train.learning_rate * d_lr_coeff, betas=config.train.betas, eps=config.train.eps)
            use_fused_optimizer = False
        except ImportError:
            if rank == 0:
                logger.warning("bitsandbytes not installed, falling back to standard optimizer")
            optim_g = optimizer_optim(net_g.parameters(), **_build_optimizer_kwargs(g_lr_coeff))
            optim_d = optimizer_optim(net_d.parameters(), **_build_optimizer_kwargs(d_lr_coeff))
    else:
        optim_g = optimizer_optim(net_g.parameters(), **_build_optimizer_kwargs(g_lr_coeff))
        optim_d = optimizer_optim(net_d.parameters(), **_build_optimizer_kwargs(d_lr_coeff))
    if rank == 0 and use_fused_optimizer:
        logger.info(f"CUDA Optimizer Training: Using fused {optimizer_choice} for enhanced CUDA performance")
else:
    # Vietnamese-RVC fallback optimizer selection (registry unavailable).
    # AnyPrecisionAdamW is only used when main_config.brain is truthy; that
    # case, and any unrecognised name, falls through to torch AdamW.
    if optimizer_choice == "AnyPrecisionAdamW" and getattr(main_config, 'brain', False):
        from arvc.engine.models.optimizers.anyprecision_optimizer import AnyPrecisionAdamW
        optimizer_optim = AnyPrecisionAdamW
    elif optimizer_choice == "RAdam":
        from torch.optim import RAdam
        optimizer_optim = RAdam
    elif optimizer_choice == "AdaBelief":
        from arvc.engine.models.optimizers.adabelief import AdaBelief
        optimizer_optim = AdaBelief
    elif optimizer_choice == "AdaBeliefV2":
        from arvc.engine.models.optimizers.adabeliefv2 import AdaBeliefV2
        optimizer_optim = AdaBeliefV2
    else:
        from torch.optim import AdamW
        optimizer_optim = AdamW
    # betas is always the (beta1, beta2) pair from the config. The previous
    # code passed the scalar 1e-8 as `betas` for AdaBelief*, while
    # torch-style optimizers take betas as a pair.
    # NOTE(review): 1e-8 is now applied as AdaBelief's eps, which appears to be
    # the original intent — confirm.
    _fallback_eps = 1e-8 if optimizer_choice.startswith("AdaBelief") else config.train.eps
    optim_g, optim_d = (
        optimizer_optim(
            net_g.parameters(),
            config.train.learning_rate * g_lr_coeff,
            betas=config.train.betas,
            eps=_fallback_eps,
        ),
        optimizer_optim(
            net_d.parameters(),
            config.train.learning_rate * d_lr_coeff,
            betas=config.train.betas,
            eps=_fallback_eps,
        ),
    )
# Mel reconstruction criterion: the multi-scale mel loss (built for the
# dataset sample rate) when enabled, otherwise a plain L1 loss.
if multiscale_mel_loss:
    fn_mel_loss = MultiScaleMelSpectrogramLoss(sample_rate=config.data.sample_rate)
else:
    fn_mel_loss = torch.nn.L1Loss()
# DDP wrapping — Vietnamese-RVC style, Advanced-RVC ZLUDA + bucket_cap_mb.
# Devices whose type starts with privateuseone, ocl, mps or xpu are left
# unwrapped; everything else gets DistributedDataParallel.
if not device.type.startswith(("privateuseone", "ocl", "mps", "xpu")):
    if not _is_zluda and torch.cuda.is_available():
        # Native CUDA: pin each replica to its device. bucket_cap_mb=25 is
        # passed explicitly (it matches PyTorch's documented DDP default).
        ddp_kwargs = {"device_ids": [device_id], "bucket_cap_mb": 25}
        net_g = DDP(net_g, **ddp_kwargs)
        net_d = DDP(net_d, **ddp_kwargs)
    else:
        # ZLUDA (gloo backend, no NCCL) or no CUDA: plain DDP without device_ids.
        net_g = DDP(net_g)
        net_d = DDP(net_d)
# Optimization: torch.compile for PyTorch 2.x+ — Advanced-RVC feature
# ZLUDA: torch.compile is not supported
# The generator is compiled only on a native CUDA device (not ZLUDA), when
# compile_model was requested and this torch build exposes torch.compile.
# Failures are non-fatal: the model is left uncompiled and rank 0 logs a warning.
# NOTE(review): torch.compile compiles lazily on the first forward call, so
# errors raised during actual graph compilation would surface there rather
# than inside these try blocks — confirm the eager fallback covers that case.
if compile_model and device.type == "cuda" and not _is_zluda and hasattr(torch, "compile"):
if rank == 0:
logger.info("Optimization: Applying torch.compile() (mode=reduce-overhead) to generator for faster training")
try:
net_g = torch.compile(net_g, mode="reduce-overhead")
except Exception as e:
if rank == 0:
logger.warning(f"torch.compile() on G failed, falling back to eager mode: {e}")
# FAST-TRAIN: also compile the discriminator. The discriminator
# runs every step (sometimes multiple times per G step), so fusing
# its kernels is expected to reduce wall-clock time. Compilation is
# meant to preserve the math, so vocal quality should be unaffected.
if fast_train:
if rank == 0:
logger.info("FAST-TRAIN: also applying torch.compile() to discriminator")
try:
net_d = torch.compile(net_d, mode="reduce-overhead")
except Exception as e:
if rank == 0:
logger.warning(f"torch.compile() on D failed, falling back to eager mode: {e}")
# Scaler state recovered from the D checkpoint on resume; stays empty otherwise.
scaler_dict = {}
try:
if rank == 0: logger.info(translations["start_training"])
# Resume target: the fixed *_latest.pth files when save_only_latest is set,
# otherwise whatever latest_checkpoint_path() selects among the numbered
# D_*.pth / G_*.pth files (presumably the newest — confirm).
d_path = os.path.join(checkpoint_path, "D_latest.pth") if save_only_latest else latest_checkpoint_path(checkpoint_path, "D_*.pth")
g_path = os.path.join(checkpoint_path, "G_latest.pth") if save_only_latest else latest_checkpoint_path(checkpoint_path, "G_*.pth")
# load_checkpoint receives the model and its optimizer (presumably restoring
# them in place) and returns a 5-tuple; only the epoch (4th item) and, from
# D, the scaler state (5th item) are kept.
_, _, _, epoch_str, scaler_dict = load_checkpoint(
logger,
d_path,
net_d,
optim_d
)
# The epoch reported by the G checkpoint overwrites D's; G's scaler state is discarded.
_, _, _, epoch_str, _ = load_checkpoint(
logger,
g_path,
net_g,
optim_g
)
if rank == 0: logger.info(translations["load_checkpoint"].format(d_path=d_path, g_path=g_path))
# Continue from the epoch after the saved one; global_step assumes every
# earlier epoch ran exactly len(train_loader) steps.
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
except (FileNotFoundError, RuntimeError, OSError, KeyError, ValueError) as e:
# SECURITY/RELIABILITY PATCH: was bare `except:` which silently swallowed
# ALL errors (including KeyboardInterrupt/SystemExit) and restarted training
# from epoch 1 — causing silent data loss on checkpoint corruption.
# Now we only catch the expected "no checkpoint yet" errors and log them;
# truly unexpected errors propagate so the user sees them.
# NOTE(review): latest_checkpoint_path/load_checkpoint are not visible here.
# If either signals "no checkpoint" with an exception outside this tuple
# (e.g. AssertionError, IndexError, TypeError on a None path), a fresh run
# will crash instead of starting at epoch 1 — confirm.
# NOTE(review): if the D load succeeds but the G load fails, net_d/optim_d
# keep the loaded state while training restarts from epoch 1.
# Sentinel values that the pretrained-model check below treats as
# "no pretrained path given".
check = ["", "None"]
epoch_str, global_step = 1, 0
if rank == 0:
logger.warning(
f"[checkpoint-load] No resumable checkpoint found or load failed ({type(e).__name__}: {e}). "
f"Starting training from epoch 1."
)
# Auto-download default pretrained models if no custom pretrained paths provided
# (Advanced-RVC feature — better than Vietnamese-RVC's approach)
if pretrainG in check and pretrainD in check and rank == 0:
# Primary and fallback pretrained URLs
primary_url = main_configs.get(
f"pretrained_{version}_url",
f"https://huggingface.co/buckets/R-Kentaren/Ultimate-RVC-Models/resolve/pretrained_{version}/"
)
# Fallback: R-Kentaren/Ultimate-RVC-Models HuggingFace Storage Bucket
_default_fallback = (
f"https://huggingface.co/buckets/R-Kentaren/Ultimate-RVC-Models/resolve/pretrained_{version}/"
)
fallback_url = main_configs.get(
f"pretrained_{version}_fallback_url",
_default_fallback
)
pretrained_save_dir = os.path.join(main_configs.get(f"pretrained_{version}_path", os.path.join(os.path.dirname(__file__), "../../assets/models", f"pretrained_{version}")))
os.makedirs(pretrained_save_dir, exist_ok=True)
pretrained_selector = {
True: { # pitch_guidance (f0 models)
24000: ("f0G24k.pth", "f0D24k.pth"),
32000: ("f0G32k.pth", "f0D32k.pth"),
40000: ("f0G40k.pth", "f0D40k.pth"),
44100: ("f0G40k.pth", "f0D40k.pth"), # reuse 40k pretrained
48000: ("f0G48k.pth", "f0D48k.pth"),
},
False: { # no pitch guidance (base models)
24000: ("G24k.pth", "D24k.pth"),
32000: ("G32k.pth", "D32k.pth"),
40000: ("G40k.pth", "D40k.pth"),