From da85c1ce7e76d660553391b9c3d46d96528c4fe7 Mon Sep 17 00:00:00 2001 From: Alexey Zolotenkov Date: Fri, 27 Mar 2026 11:52:39 +0100 Subject: [PATCH 1/4] Fix bucket sampler cache alignment in DreamBooth scripts --- examples/dreambooth/train_dreambooth_lora_flux2.py | 3 +-- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 3 +-- examples/dreambooth/train_dreambooth_lora_flux2_klein.py | 3 +-- .../dreambooth/train_dreambooth_lora_flux2_klein_img2img.py | 3 +-- examples/dreambooth/train_dreambooth_lora_z_image.py | 3 +-- 5 files changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 24d098add017..9d9422189997 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -1005,8 +1005,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.sampler_len += 1 # Count the number of batches def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: yield batch diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index e18909e6dfd7..9fe0e81d921d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1003,8 +1003,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.sampler_len += 1 # Count the number of batches def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: yield batch diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 268d0148e446..39f1b2ca7858 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -1000,8 +1000,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.sampler_len += 1 # Count the number of batches def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: yield batch diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 0205f2e9e65f..aa13e84bcf74 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -999,8 +999,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.sampler_len += 1 # Count the number of batches def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: yield batch diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index 623ae4d2aca3..a2e84fa1b692 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -994,8 +994,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.sampler_len += 1 # Count the number of batches def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: yield batch From 7dccd529c87756b46dad0535d0d18f9f65e7d395 Mon Sep 17 00:00:00 2001 From: Alexey Zolotenkov Date: Fri, 27 Mar 2026 12:37:55 +0100 Subject: [PATCH 2/4] Shuffle precomputed DreamBooth bucket batches once --- examples/dreambooth/train_dreambooth_lora_flux2.py | 4 ++++ examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 4 ++++ examples/dreambooth/train_dreambooth_lora_flux2_klein.py | 4 ++++ .../dreambooth/train_dreambooth_lora_flux2_klein_img2img.py | 4 ++++ examples/dreambooth/train_dreambooth_lora_z_image.py | 4 ++++ 5 files changed, 20 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 9d9422189997..bb3767dff832 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -1004,6 +1004,10 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 9fe0e81d921d..8534181e35d3 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1002,6 +1002,10 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 39f1b2ca7858..a76fa54a0645 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -999,6 +999,10 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index aa13e84bcf74..50fd907a5974 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -998,6 +998,10 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index a2e84fa1b692..0351a934fe51 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -993,6 +993,10 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): # Keep the precomputed batch order stable so step-indexed caches stay aligned. for batch in self.batches: From 04c6304eb3b3f58a11bcb606e8cf07d26999429f Mon Sep 17 00:00:00 2001 From: Alexey Zolotenkov Date: Mon, 30 Mar 2026 13:46:26 +0200 Subject: [PATCH 3/4] Scope stable bucket ordering to cached DreamBooth batches --- .../dreambooth/train_dreambooth_lora_flux2.py | 26 +++++++++++++------ .../train_dreambooth_lora_flux2_img2img.py | 26 +++++++++++++------ .../train_dreambooth_lora_flux2_klein.py | 26 +++++++++++++------ ...ain_dreambooth_lora_flux2_klein_img2img.py | 26 +++++++++++++------ .../train_dreambooth_lora_z_image.py | 26 +++++++++++++------ 5 files changed, 90 insertions(+), 40 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index bb3767dff832..36fb546a1751 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -974,7 +974,9 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -983,6 +985,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -1004,12 +1007,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches - # Shuffle the precomputed batches once to mix buckets while keeping - # the order stable across epochs for step-indexed caches. - random.shuffle(self.batches) + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) def __iter__(self): - # Keep the precomputed batch order stable so step-indexed caches stay aligned. + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1465,7 +1470,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1582,8 +1593,7 @@ def _encode_single(prompt: str): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] text_ids_cache = [] latents_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 8534181e35d3..9cec2c55aa11 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -972,7 +972,9 @@ def collate_fn(examples): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -981,6 +983,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -1002,12 +1005,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches - # Shuffle the precomputed batches once to mix buckets while keeping - # the order stable across epochs for step-indexed caches. - random.shuffle(self.batches) + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) def __iter__(self): - # Keep the precomputed batch order stable so step-indexed caches stay aligned. + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1412,7 +1417,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1515,8 +1526,7 @@ def _encode_single(prompt: str): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] text_ids_cache = [] latents_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index a76fa54a0645..994239c428b0 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -969,7 +969,9 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -978,6 +980,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -999,12 +1002,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches - # Shuffle the precomputed batches once to mix buckets while keeping - # the order stable across epochs for step-indexed caches. - random.shuffle(self.batches) + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) def __iter__(self): - # Keep the precomputed batch order stable so step-indexed caches stay aligned. + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1458,7 +1463,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1525,8 +1536,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] text_ids_cache = [] latents_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 50fd907a5974..877b5f9316f3 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -968,7 +968,9 @@ def collate_fn(examples): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -977,6 +979,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -998,12 +1001,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches - # Shuffle the precomputed batches once to mix buckets while keeping - # the order stable across epochs for step-indexed caches. - random.shuffle(self.batches) + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) def __iter__(self): - # Keep the precomputed batch order stable so step-indexed caches stay aligned. + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1406,7 +1411,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1466,8 +1477,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] text_ids_cache = [] latents_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index 0351a934fe51..13a90e8f3b54 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -963,7 +963,9 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -972,6 +974,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -993,12 +996,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches - # Shuffle the precomputed batches once to mix buckets while keeping - # the order stable across epochs for step-indexed caches. - random.shuffle(self.batches) + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) def __iter__(self): - # Keep the precomputed batch order stable so step-indexed caches stay aligned. + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1452,7 +1457,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1512,8 +1523,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts - if precompute_latents: + if has_step_indexed_caches: prompt_embeds_cache = [] latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): From 3cabf561296f048af3fa4f253cd3d6fdaab01597 Mon Sep 17 00:00:00 2001 From: Alexey Zolotenkov Date: Mon, 30 Mar 2026 13:56:19 +0200 Subject: [PATCH 4/4] Format DreamBooth bucket sampler updates --- examples/dreambooth/train_dreambooth_lora_flux2.py | 6 +++++- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 6 +++++- examples/dreambooth/train_dreambooth_lora_flux2_klein.py | 6 +++++- .../dreambooth/train_dreambooth_lora_flux2_klein_img2img.py | 6 +++++- examples/dreambooth/train_dreambooth_lora_z_image.py | 6 +++++- 5 files changed, 25 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 36fb546a1751..a3be2efb7ea7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -975,7 +975,11 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): def __init__( - self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 9cec2c55aa11..1d012aea5ce4 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -973,7 +973,11 @@ def collate_fn(examples): class BucketBatchSampler(BatchSampler): def __init__( - self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 994239c428b0..d54f14ffbfce 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -970,7 +970,11 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): def __init__( - self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 877b5f9316f3..1ed9f35ecf1c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -969,7 +969,11 @@ def collate_fn(examples): class BucketBatchSampler(BatchSampler): def __init__( - self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index 13a90e8f3b54..f5b10a1059c4 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -964,7 +964,11 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): def __init__( - self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))