diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 24d098add017..a3be2efb7ea7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -974,7 +974,13 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -983,6 +989,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -1004,9 +1011,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1462,7 +1474,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1579,8 +1597,7 @@ def _encode_single(prompt: str): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] text_ids_cache = [] latents_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index e18909e6dfd7..1d012aea5ce4 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -972,7 +972,13 @@ def collate_fn(examples): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -981,6 +987,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -1002,9 +1009,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1409,7 +1421,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1512,8 +1530,7 @@ def _encode_single(prompt: str): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] text_ids_cache = [] latents_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 268d0148e446..d54f14ffbfce 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -969,7 +969,13 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -978,6 +984,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -999,9 +1006,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1455,7 +1467,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1522,8 +1540,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] text_ids_cache = [] latents_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 0205f2e9e65f..1ed9f35ecf1c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -968,7 +968,13 @@ def collate_fn(examples): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -977,6 +983,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -998,9 +1005,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1403,7 +1415,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1463,8 +1481,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] text_ids_cache = [] latents_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index 623ae4d2aca3..f5b10a1059c4 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -963,7 +963,13 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -972,6 +978,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -993,9 +1000,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1449,7 +1461,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1509,8 +1527,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"):