Commit 99ba84b

Merge pull request #171 from cloneofsimo/develop

v0.1.6

2 parents 848db91 + a9354e6, commit 99ba84b

16 files changed: +801 -56 lines

README.md

Lines changed: 5 additions & 0 deletions

```diff
@@ -37,6 +37,7 @@
 - Fine-tune Stable diffusion models twice as fast than dreambooth method, by Low-rank Adaptation
 - Get insanely small end result (1MB ~ 6MB), easy to share and download.
 - Compatible with `diffusers`
+- Support for inpainting
 - Sometimes _even better performance_ than full fine-tuning (but left as future work for extensive comparisons)
 - Merge checkpoints + Build recipes by merging LoRAs together
 - Pipeline to fine-tune CLIP + Unet + token to gain better results.
@@ -50,6 +51,10 @@
 
 # UPDATES & Notes
 
+### 2023/02/06
+
+- Support for training inpainting on LoRA PTI. Use flag `--train-inpainting` with a inpainting stable diffusion base model (see `inpainting_example.sh`).
+
 ### 2023/02/01
 
 - LoRA Joining is now available with `--mode=ljl` flag. Only three parameters are required : `path_to_lora1`, `path_to_lora2`, and `path_to_save`.
```
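Both updates are exercised through `cli_lora_pti.py` and `cli_lora_add.py` (diffs below). As rough orientation, here is a hypothetical Python-side call for the new inpainting training mode; only `train_inpainting`, `use_template`, and `placeholder_tokens` are confirmed by this commit's diff, and the remaining argument names, paths, and model id are assumptions and placeholders:

```python
# Hypothetical sketch, not from this commit. Argument names other than
# `train_inpainting`, `use_template`, and `placeholder_tokens` are assumed,
# and all values are placeholders.
from lora_diffusion.cli_lora_pti import train

train(
    pretrained_model_name_or_path="runwayml/stable-diffusion-inpainting",  # assumed inpainting base model id
    instance_data_dir="./data/subject",  # placeholder path
    output_dir="./output",               # placeholder path
    train_inpainting=True,               # Python-side equivalent of --train-inpainting
    use_template="object",
    placeholder_tokens="<s1>",
)
```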

Binary files changed:

- contents/inpainting_base_image.png (619 KB)
- contents/inpainting_mask.png (3.49 KB)
- contents/lora_pti_inpainting.jpg (31.6 KB)
- unnamed image (371 KB)
- example_loras/and.safetensors (11.8 MB, binary file not shown)
- unnamed safetensors file (5.92 MB, binary file not shown)

lora_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -2,3 +2,4 @@
 from .dataset import *
 from .utils import *
 from .preprocess_files import *
+from .lora_manager import *
```

lora_diffusion/cli_lora_add.py

Lines changed: 2 additions & 48 deletions

```diff
@@ -12,6 +12,7 @@
     collapse_lora,
     monkeypatch_remove_lora,
 )
+from .lora_manager import lora_join
 from .to_ckpt_v2 import convert_to_ckpt
 
 
@@ -20,53 +21,6 @@ def _text_lora_path(path: str) -> str:
     return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
 
 
-def lora_join(lora_safetenors: list):
-    metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
-    total_metadata = {}
-    total_tensor = {}
-    total_rank = 0
-    for _metadata in metadatas:
-        rankset = []
-        for k, v in _metadata.items():
-            if k.endswith("rank"):
-                rankset.append(int(v))
-
-        assert len(set(rankset)) == 1, "Rank should be the same per model"
-        total_rank += rankset[0]
-        total_metadata.update(_metadata)
-
-    tensorkeys = set()
-    for safelora in lora_safetenors:
-        tensorkeys.update(safelora.keys())
-
-    for keys in tensorkeys:
-        if keys.startswith("text_encoder") or keys.startswith("unet"):
-            tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]
-
-            is_down = keys.endswith("down")
-
-            if is_down:
-                _tensor = torch.cat(tensorset, dim=0)
-                assert _tensor.shape[0] == total_rank
-            else:
-                _tensor = torch.cat(tensorset, dim=1)
-                assert _tensor.shape[1] == total_rank
-
-            total_tensor[keys] = _tensor
-            keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
-            total_metadata[keys_rank] = str(total_rank)
-
-    for idx, safelora in enumerate(lora_safetenors):
-        tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
-        for jdx, token in enumerate(sorted(tokens)):
-            del total_metadata[token]
-            total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
-            total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
-            print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
-
-    return total_tensor, total_metadata
-
-
 def add(
     path_1: str,
     path_2: str,
```
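The function deleted above now lives in `lora_manager` (re-exported via `__init__.py`). For intuition: the join concatenates each pair of "down" weights along dim 0 and each pair of "up" weights along dim 1, so the joined LoRA has rank `r1 + r2` and its weight update is exactly the sum of the two originals. A minimal, self-contained check of that identity, with illustrative shapes:

```python
import torch

# Illustrative shapes: a LoRA stores one down/up pair per adapted module,
# with down of shape (rank, d_in) and up of shape (d_out, rank).
d_in, d_out, r1, r2 = 16, 16, 4, 2

down1, up1 = torch.randn(r1, d_in), torch.randn(d_out, r1)
down2, up2 = torch.randn(r2, d_in), torch.randn(d_out, r2)

# Join as lora_join does: "down" weights concatenated on dim 0,
# "up" weights on dim 1, giving a single LoRA of rank r1 + r2.
down_joined = torch.cat([down1, down2], dim=0)  # (r1 + r2, d_in)
up_joined = torch.cat([up1, up2], dim=1)        # (d_out, r1 + r2)

# The joined low-rank update equals the sum of the individual updates.
assert torch.allclose(
    up_joined @ down_joined, up1 @ down1 + up2 @ down2, atol=1e-5
)
```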
```diff
@@ -221,7 +175,7 @@ def add(
         safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
         safeloras_2 = safe_open(path_2, framework="pt", device="cpu")
 
-        total_tensor, total_metadata = lora_join([safeloras_1, safeloras_2])
+        total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
         save_file(total_tensor, output_path, total_metadata)
 
     else:
```
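Per the hunk above, `lora_join` now returns four values, of which `add` keeps the first two. A minimal sketch of the `--mode=ljl` join path; the file paths are placeholders, and the two discarded return values are not named anywhere in this diff:

```python
# Sketch mirroring the --mode=ljl branch of `add` above.
# File paths are placeholders.
from safetensors import safe_open
from safetensors.torch import save_file

from lora_diffusion.lora_manager import lora_join

safeloras_1 = safe_open("path_to_lora1.safetensors", framework="pt", device="cpu")
safeloras_2 = safe_open("path_to_lora2.safetensors", framework="pt", device="cpu")

# Keep the joined tensors and metadata; discard the last two return values,
# just as the committed call site does.
total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
save_file(total_tensor, "path_to_save.safetensors", total_metadata)
```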

lora_diffusion/cli_lora_pti.py

Lines changed: 81 additions & 6 deletions

```diff
@@ -168,13 +168,60 @@ def collate_fn(examples):
 
     return train_dataloader
 
+def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+    def collate_fn(examples):
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+        mask_values = [example["instance_masks"] for example in examples]
+        masked_image_values = [example["instance_masked_images"] for example in examples]
+
+        # Concat class and instance examples for prior preservation.
+        # We do this to avoid doing two forward passes.
+        if examples[0].get("class_prompt_ids", None) is not None:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+            mask_values += [example["class_masks"] for example in examples]
+            masked_image_values += [example["class_masked_images"] for example in examples]
+
+        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
+        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()
+
+        input_ids = tokenizer.pad(
+            {"input_ids": input_ids},
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            return_tensors="pt",
+        ).input_ids
+
+        batch = {
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+            "mask_values": mask_values,
+            "masked_image_values": masked_image_values
+        }
+
+        if examples[0].get("mask", None) is not None:
+            batch["mask"] = torch.stack([example["mask"] for example in examples])
+
+        return batch
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=train_batch_size,
+        shuffle=True,
+        collate_fn=collate_fn,
+    )
+
+    return train_dataloader
 
 def loss_step(
     batch,
     unet,
     vae,
     text_encoder,
     scheduler,
+    train_inpainting=False,
     t_mutliplier=1.0,
     mixed_precision=False,
     mask_temperature=1.0,
```
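The new `inpainting_dataloader` assumes each dataset example carries mask and masked-image tensors alongside the usual text2img fields. A sketch of the per-example dict its `collate_fn` consumes; the shapes and masking convention are assumptions for a 512x512 pipeline, and the dataset-side changes that produce these keys are not part of this diff:

```python
import torch

# Assumed per-example layout: key names come from collate_fn above,
# while shapes and the masking convention are illustrative assumptions.
example = {
    "instance_prompt_ids": torch.randint(0, 49408, (77,)),  # tokenized caption
    "instance_images": torch.randn(3, 512, 512),            # training image
    "instance_masks": torch.rand(1, 512, 512),              # region to inpaint
    "instance_masked_images": torch.randn(3, 512, 512),     # image with masked region hidden
}
```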
```diff
@@ -186,6 +233,16 @@ def loss_step(
     ).latent_dist.sample()
     latents = latents * 0.18215
 
+    if train_inpainting:
+        masked_image_latents = vae.encode(
+            batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+        ).latent_dist.sample()
+        masked_image_latents = masked_image_latents * 0.18215
+        mask = F.interpolate(
+            batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
+            scale_factor=1/8
+        )
+
     noise = torch.randn_like(latents)
     bsz = latents.shape[0]
 
```
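Two constants here are standard Stable Diffusion conventions rather than anything new in this commit: 0.18215 is the usual latent scaling factor (the same one applied to `latents` above), and `scale_factor=1/8` brings the pixel-space mask down to the VAE's latent resolution. A standalone sketch using diffusers' `AutoencoderKL`; the model id is a placeholder:

```python
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL

# Placeholder model id; any SD inpainting checkpoint's VAE behaves the same.
vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-inpainting", subfolder="vae"
)

masked_images = torch.randn(1, 3, 512, 512)  # placeholder batch
masks = torch.rand(1, 1, 512, 512)

with torch.no_grad():
    masked_image_latents = vae.encode(masked_images).latent_dist.sample() * 0.18215
mask = F.interpolate(masks, scale_factor=1 / 8)

print(masked_image_latents.shape, mask.shape)  # (1, 4, 64, 64) and (1, 1, 64, 64)
```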
```diff
@@ -199,21 +256,26 @@
 
     noisy_latents = scheduler.add_noise(latents, noise, timesteps)
 
+    if train_inpainting:
+        latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
+    else:
+        latent_model_input = noisy_latents
+
     if mixed_precision:
         with torch.cuda.amp.autocast():
 
             encoder_hidden_states = text_encoder(
                 batch["input_ids"].to(text_encoder.device)
             )[0]
 
-            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+            model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
     else:
 
         encoder_hidden_states = text_encoder(
             batch["input_ids"].to(text_encoder.device)
         )[0]
 
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+        model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
 
     if scheduler.config.prediction_type == "epsilon":
         target = noise
```
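This concatenation is what distinguishes the inpainting forward pass: 4 noisy-latent channels, 1 mask channel at latent resolution, and 4 masked-image-latent channels give the 9-channel input layout that Stable Diffusion inpainting UNets expect. A quick standalone shape check, assuming 512x512 inputs:

```python
import torch
import torch.nn.functional as F

# Shape check only; 512x512 images give 64x64 latents.
bsz = 2
noisy_latents = torch.randn(bsz, 4, 64, 64)
masked_image_latents = torch.randn(bsz, 4, 64, 64)
mask = F.interpolate(torch.rand(bsz, 1, 512, 512), scale_factor=1 / 8)

latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
assert latent_model_input.shape == (bsz, 9, 64, 64)  # 4 + 1 + 4 channels
```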
```diff
@@ -270,6 +332,7 @@ def train_inversion(
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
     class_token: str = "person",
+    train_inpainting: bool = False,
     mixed_precision: bool = False,
     clip_ti_decay: bool = True,
 ):
@@ -302,6 +365,7 @@ def train_inversion(
                     vae,
                     text_encoder,
                     scheduler,
+                    train_inpainting=train_inpainting,
                     mixed_precision=mixed_precision,
                 )
                 / accum_iter
@@ -384,7 +448,7 @@ def train_inversion(
                 # open all images in test_image_path
                 images = []
                 for file in os.listdir(test_image_path):
-                    if file.endswith(".png") or file.endswith(".jpg") or file.endswith(".jpeg"):
+                    if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"):
                         images.append(
                             Image.open(os.path.join(test_image_path, file))
                         )
@@ -429,6 +493,7 @@ def perform_tuning(
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
     class_token: str = "person",
+    train_inpainting: bool = False,
 ):
 
     progress_bar = tqdm(range(num_steps))
@@ -457,6 +522,7 @@ def perform_tuning(
                     vae,
                     text_encoder,
                     scheduler,
+                    train_inpainting=train_inpainting,
                     t_mutliplier=0.8,
                     mixed_precision=True,
                     mask_temperature=mask_temperature,
@@ -565,6 +631,7 @@ def train(
     stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
     use_template: Literal[None, "object", "style"] = None,
+    train_inpainting: bool = False,
     placeholder_tokens: str = "",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: Optional[str] = None,
@@ -716,13 +783,19 @@ def train(
         color_jitter=color_jitter,
         use_face_segmentation_condition=use_face_segmentation_condition,
         use_mask_captioned_data=use_mask_captioned_data,
+        train_inpainting=train_inpainting,
     )
 
     train_dataset.blur_amount = 200
 
-    train_dataloader = text2img_dataloader(
-        train_dataset, train_batch_size, tokenizer, vae, text_encoder
-    )
+    if train_inpainting:
+        train_dataloader = inpainting_dataloader(
+            train_dataset, train_batch_size, tokenizer, vae, text_encoder
+        )
+    else:
+        train_dataloader = text2img_dataloader(
+            train_dataset, train_batch_size, tokenizer, vae, text_encoder
+        )
 
     index_no_updates = torch.arange(len(tokenizer)) != -1
 
@@ -776,6 +849,7 @@ def train(
         log_wandb=log_wandb,
         wandb_log_prompt_cnt=wandb_log_prompt_cnt,
         class_token=class_token,
+        train_inpainting=train_inpainting,
         mixed_precision=False,
         tokenizer=tokenizer,
         clip_ti_decay=clip_ti_decay,
@@ -883,6 +957,7 @@ def train(
         log_wandb=log_wandb,
         wandb_log_prompt_cnt=wandb_log_prompt_cnt,
         class_token=class_token,
+        train_inpainting=train_inpainting,
     )
 
 
```