Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,13 @@ def collate_fn(examples, with_prior_preservation=False):


class BucketBatchSampler(BatchSampler):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
def __init__(
self,
dataset: DreamBoothDataset,
batch_size: int,
drop_last: bool = False,
shuffle_batches_each_epoch: bool = True,
):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
Expand All @@ -983,6 +989,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
Expand All @@ -1004,9 +1011,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches

if not self.shuffle_batches_each_epoch:
# Shuffle the precomputed batches once to mix buckets while keeping
# the order stable across epochs for step-indexed caches.
random.shuffle(self.batches)

def __iter__(self):
# Shuffle the order of the batches each epoch
random.shuffle(self.batches)
if self.shuffle_batches_each_epoch:
random.shuffle(self.batches)
for batch in self.batches:
yield batch

Expand Down Expand Up @@ -1462,7 +1474,13 @@ def load_model_hook(models, input_dir):
center_crop=args.center_crop,
buckets=buckets,
)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
batch_sampler = BucketBatchSampler(
train_dataset,
batch_size=args.train_batch_size,
drop_last=True,
shuffle_batches_each_epoch=not has_step_indexed_caches,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
Expand Down Expand Up @@ -1579,8 +1597,7 @@ def _encode_single(prompt: str):
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
if has_step_indexed_caches:
prompt_embeds_cache = []
text_ids_cache = []
latents_cache = []
Expand Down
29 changes: 23 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,13 @@ def collate_fn(examples):


class BucketBatchSampler(BatchSampler):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
def __init__(
self,
dataset: DreamBoothDataset,
batch_size: int,
drop_last: bool = False,
shuffle_batches_each_epoch: bool = True,
):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
Expand All @@ -981,6 +987,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
Expand All @@ -1002,9 +1009,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches

if not self.shuffle_batches_each_epoch:
# Shuffle the precomputed batches once to mix buckets while keeping
# the order stable across epochs for step-indexed caches.
random.shuffle(self.batches)

def __iter__(self):
# Shuffle the order of the batches each epoch
random.shuffle(self.batches)
if self.shuffle_batches_each_epoch:
random.shuffle(self.batches)
for batch in self.batches:
yield batch

Expand Down Expand Up @@ -1409,7 +1421,13 @@ def load_model_hook(models, input_dir):
center_crop=args.center_crop,
buckets=buckets,
)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
batch_sampler = BucketBatchSampler(
train_dataset,
batch_size=args.train_batch_size,
drop_last=True,
shuffle_batches_each_epoch=not has_step_indexed_caches,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
Expand Down Expand Up @@ -1512,8 +1530,7 @@ def _encode_single(prompt: str):
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
if has_step_indexed_caches:
prompt_embeds_cache = []
text_ids_cache = []
latents_cache = []
Expand Down
29 changes: 23 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,13 @@ def collate_fn(examples, with_prior_preservation=False):


class BucketBatchSampler(BatchSampler):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
def __init__(
self,
dataset: DreamBoothDataset,
batch_size: int,
drop_last: bool = False,
shuffle_batches_each_epoch: bool = True,
):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
Expand All @@ -978,6 +984,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
Expand All @@ -999,9 +1006,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches

if not self.shuffle_batches_each_epoch:
# Shuffle the precomputed batches once to mix buckets while keeping
# the order stable across epochs for step-indexed caches.
random.shuffle(self.batches)

def __iter__(self):
# Shuffle the order of the batches each epoch
random.shuffle(self.batches)
if self.shuffle_batches_each_epoch:
random.shuffle(self.batches)
for batch in self.batches:
yield batch

Expand Down Expand Up @@ -1455,7 +1467,13 @@ def load_model_hook(models, input_dir):
center_crop=args.center_crop,
buckets=buckets,
)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
batch_sampler = BucketBatchSampler(
train_dataset,
batch_size=args.train_batch_size,
drop_last=True,
shuffle_batches_each_epoch=not has_step_indexed_caches,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
Expand Down Expand Up @@ -1522,8 +1540,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
if has_step_indexed_caches:
prompt_embeds_cache = []
text_ids_cache = []
latents_cache = []
Expand Down
29 changes: 23 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,13 @@ def collate_fn(examples):


class BucketBatchSampler(BatchSampler):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
def __init__(
self,
dataset: DreamBoothDataset,
batch_size: int,
drop_last: bool = False,
shuffle_batches_each_epoch: bool = True,
):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
Expand All @@ -977,6 +983,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
Expand All @@ -998,9 +1005,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches

if not self.shuffle_batches_each_epoch:
# Shuffle the precomputed batches once to mix buckets while keeping
# the order stable across epochs for step-indexed caches.
random.shuffle(self.batches)

def __iter__(self):
# Shuffle the order of the batches each epoch
random.shuffle(self.batches)
if self.shuffle_batches_each_epoch:
random.shuffle(self.batches)
for batch in self.batches:
yield batch

Expand Down Expand Up @@ -1403,7 +1415,13 @@ def load_model_hook(models, input_dir):
center_crop=args.center_crop,
buckets=buckets,
)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
batch_sampler = BucketBatchSampler(
train_dataset,
batch_size=args.train_batch_size,
drop_last=True,
shuffle_batches_each_epoch=not has_step_indexed_caches,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
Expand Down Expand Up @@ -1463,8 +1481,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
if has_step_indexed_caches:
prompt_embeds_cache = []
text_ids_cache = []
latents_cache = []
Expand Down
29 changes: 23 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,13 @@ def collate_fn(examples, with_prior_preservation=False):


class BucketBatchSampler(BatchSampler):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
def __init__(
self,
dataset: DreamBoothDataset,
batch_size: int,
drop_last: bool = False,
shuffle_batches_each_epoch: bool = True,
):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
Expand All @@ -972,6 +978,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
Expand All @@ -993,9 +1000,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches

if not self.shuffle_batches_each_epoch:
# Shuffle the precomputed batches once to mix buckets while keeping
# the order stable across epochs for step-indexed caches.
random.shuffle(self.batches)

def __iter__(self):
# Shuffle the order of the batches each epoch
random.shuffle(self.batches)
if self.shuffle_batches_each_epoch:
random.shuffle(self.batches)
for batch in self.batches:
yield batch

Expand Down Expand Up @@ -1449,7 +1461,13 @@ def load_model_hook(models, input_dir):
center_crop=args.center_crop,
buckets=buckets,
)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
batch_sampler = BucketBatchSampler(
train_dataset,
batch_size=args.train_batch_size,
drop_last=True,
shuffle_batches_each_epoch=not has_step_indexed_caches,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
Expand Down Expand Up @@ -1509,8 +1527,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
if has_step_indexed_caches:
prompt_embeds_cache = []
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
Expand Down