@@ -168,13 +168,60 @@ def collate_fn(examples):
168168
169169 return train_dataloader
170170
def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
    """Build a shuffling DataLoader whose batches carry inpainting inputs.

    Args:
        train_dataset: Dataset handed to ``torch.utils.data.DataLoader``.
            Every example is indexed by ``instance_prompt_ids``,
            ``instance_images``, ``instance_masks`` and
            ``instance_masked_images``; the ``class_*`` counterparts and
            ``mask`` are optional.
        train_batch_size: Number of examples per batch.
        tokenizer: Object providing ``pad(...)`` and ``model_max_length``;
            used to pad the prompt ids to a fixed length.
        vae: Not used by this function; kept in the signature.
        text_encoder: Not used by this function; kept in the signature.

    Returns:
        A ``DataLoader`` (``shuffle=True``) yielding dicts with the keys
        ``input_ids``, ``pixel_values``, ``mask_values`` and
        ``masked_image_values``, plus ``mask`` when the first example of
        the batch carries a non-None ``mask`` entry.
    """
    instance_keys = (
        "instance_prompt_ids",
        "instance_images",
        "instance_masks",
        "instance_masked_images",
    )
    class_keys = (
        "class_prompt_ids",
        "class_images",
        "class_masks",
        "class_masked_images",
    )

    def _columns(samples, keys):
        # One list per key, each holding that field for every sample in order.
        return [[sample[key] for sample in samples] for key in keys]

    def _stack_float(tensors):
        # Stack along a new leading dim, force contiguous layout, cast to float32.
        stacked = torch.stack(tensors)
        return stacked.to(memory_format=torch.contiguous_format).float()

    def collate_fn(examples):
        prompt_ids, images, masks, masked_images = _columns(examples, instance_keys)

        # Prior preservation: append the class examples after the instance
        # examples so a single forward pass covers both.
        first = examples[0]
        if first.get("class_prompt_ids", None) is not None:
            extras = _columns(examples, class_keys)
            for bucket, extra in zip((prompt_ids, images, masks, masked_images), extras):
                bucket.extend(extra)

        pixel_values = _stack_float(images)
        mask_values = _stack_float(masks)
        masked_image_values = _stack_float(masked_images)

        padded = tokenizer.pad(
            {"input_ids": prompt_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )

        batch = {
            "input_ids": padded.input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
            "masked_image_values": masked_image_values,
        }

        # The optional per-example "mask" is stacked over the instance
        # examples only (it is not duplicated for the class half).
        if first.get("mask", None) is not None:
            batch["mask"] = torch.stack([sample["mask"] for sample in examples])

        return batch

    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
171217
172218def loss_step (
173219 batch ,
174220 unet ,
175221 vae ,
176222 text_encoder ,
177223 scheduler ,
224+ train_inpainting = False ,
178225 t_mutliplier = 1.0 ,
179226 mixed_precision = False ,
180227 mask_temperature = 1.0 ,
@@ -186,6 +233,16 @@ def loss_step(
186233 ).latent_dist .sample ()
187234 latents = latents * 0.18215
188235
236+ if train_inpainting :
237+ masked_image_latents = vae .encode (
238+ batch ["masked_image_values" ].to (dtype = weight_dtype ).to (unet .device )
239+ ).latent_dist .sample ()
240+ masked_image_latents = masked_image_latents * 0.18215
241+ mask = F .interpolate (
242+ batch ["mask_values" ].to (dtype = weight_dtype ).to (unet .device ),
243+ scale_factor = 1 / 8
244+ )
245+
189246 noise = torch .randn_like (latents )
190247 bsz = latents .shape [0 ]
191248
@@ -199,21 +256,26 @@ def loss_step(
199256
200257 noisy_latents = scheduler .add_noise (latents , noise , timesteps )
201258
259+ if train_inpainting :
260+ latent_model_input = torch .cat ([noisy_latents , mask , masked_image_latents ], dim = 1 )
261+ else :
262+ latent_model_input = noisy_latents
263+
202264 if mixed_precision :
203265 with torch .cuda .amp .autocast ():
204266
205267 encoder_hidden_states = text_encoder (
206268 batch ["input_ids" ].to (text_encoder .device )
207269 )[0 ]
208270
209- model_pred = unet (noisy_latents , timesteps , encoder_hidden_states ).sample
271+ model_pred = unet (latent_model_input , timesteps , encoder_hidden_states ).sample
210272 else :
211273
212274 encoder_hidden_states = text_encoder (
213275 batch ["input_ids" ].to (text_encoder .device )
214276 )[0 ]
215277
216- model_pred = unet (noisy_latents , timesteps , encoder_hidden_states ).sample
278+ model_pred = unet (latent_model_input , timesteps , encoder_hidden_states ).sample
217279
218280 if scheduler .config .prediction_type == "epsilon" :
219281 target = noise
@@ -270,6 +332,7 @@ def train_inversion(
270332 log_wandb : bool = False ,
271333 wandb_log_prompt_cnt : int = 10 ,
272334 class_token : str = "person" ,
335+ train_inpainting : bool = False ,
273336 mixed_precision : bool = False ,
274337 clip_ti_decay : bool = True ,
275338):
@@ -302,6 +365,7 @@ def train_inversion(
302365 vae ,
303366 text_encoder ,
304367 scheduler ,
368+ train_inpainting = train_inpainting ,
305369 mixed_precision = mixed_precision ,
306370 )
307371 / accum_iter
@@ -384,7 +448,7 @@ def train_inversion(
384448 # open all images in test_image_path
385449 images = []
386450 for file in os .listdir (test_image_path ):
387- if file .endswith (".png" ) or file .endswith (".jpg" ) or file .endswith (".jpeg" ):
451+ if file .lower (). endswith (".png" ) or file .lower (). endswith (".jpg" ) or file . lower () .endswith (".jpeg" ):
388452 images .append (
389453 Image .open (os .path .join (test_image_path , file ))
390454 )
@@ -429,6 +493,7 @@ def perform_tuning(
429493 log_wandb : bool = False ,
430494 wandb_log_prompt_cnt : int = 10 ,
431495 class_token : str = "person" ,
496+ train_inpainting : bool = False ,
432497):
433498
434499 progress_bar = tqdm (range (num_steps ))
@@ -457,6 +522,7 @@ def perform_tuning(
457522 vae ,
458523 text_encoder ,
459524 scheduler ,
525+ train_inpainting = train_inpainting ,
460526 t_mutliplier = 0.8 ,
461527 mixed_precision = True ,
462528 mask_temperature = mask_temperature ,
@@ -565,6 +631,7 @@ def train(
565631 stochastic_attribute : Optional [str ] = None ,
566632 perform_inversion : bool = True ,
567633 use_template : Literal [None , "object" , "style" ] = None ,
634+ train_inpainting : bool = False ,
568635 placeholder_tokens : str = "" ,
569636 placeholder_token_at_data : Optional [str ] = None ,
570637 initializer_tokens : Optional [str ] = None ,
@@ -716,13 +783,19 @@ def train(
716783 color_jitter = color_jitter ,
717784 use_face_segmentation_condition = use_face_segmentation_condition ,
718785 use_mask_captioned_data = use_mask_captioned_data ,
786+ train_inpainting = train_inpainting ,
719787 )
720788
721789 train_dataset .blur_amount = 200
722790
723- train_dataloader = text2img_dataloader (
724- train_dataset , train_batch_size , tokenizer , vae , text_encoder
725- )
791+ if train_inpainting :
792+ train_dataloader = inpainting_dataloader (
793+ train_dataset , train_batch_size , tokenizer , vae , text_encoder
794+ )
795+ else :
796+ train_dataloader = text2img_dataloader (
797+ train_dataset , train_batch_size , tokenizer , vae , text_encoder
798+ )
726799
727800 index_no_updates = torch .arange (len (tokenizer )) != - 1
728801
@@ -776,6 +849,7 @@ def train(
776849 log_wandb = log_wandb ,
777850 wandb_log_prompt_cnt = wandb_log_prompt_cnt ,
778851 class_token = class_token ,
852+ train_inpainting = train_inpainting ,
779853 mixed_precision = False ,
780854 tokenizer = tokenizer ,
781855 clip_ti_decay = clip_ti_decay ,
@@ -883,6 +957,7 @@ def train(
883957 log_wandb = log_wandb ,
884958 wandb_log_prompt_cnt = wandb_log_prompt_cnt ,
885959 class_token = class_token ,
960+ train_inpainting = train_inpainting ,
886961 )
887962
888963
0 commit comments