Using DDP/FSDP
PyTorch DDP and FSDP work by wrapping the original model with their respective APIs.
Cornstarch multimodal LLMs follow the same design principle, so DDP and FSDP can be used with Cornstarch directly.
An example of using PyTorch DDP:

```python
import torch
import torch.distributed as dist
from torch.optim.adam import Adam
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers.models.clip import CLIPVisionModel, CLIPImageProcessor

from cornstarch.models.multimodal_language_model import (
    ModalEncoderModule,
    MultimodalModel,
    MultimodalModelProcessor,
)

# Create a multimodal LLM
vision_model_name = "openai/clip-vit-base-patch32"
language_model_name = "meta-llama/Llama-3.2-1B"
vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
language_model = AutoModelForCausalLM.from_pretrained(language_model_name)

model = MultimodalModel(
    encoders={"vision": vision_encoder},
    language_model=language_model,
).to(dtype=torch.bfloat16, device="cuda")

# Create a processor
image_processor = CLIPImageProcessor.from_pretrained(vision_model_name)
text_tokenizer = AutoTokenizer.from_pretrained(language_model_name, use_fast=True)
text_tokenizer.pad_token_id = text_tokenizer.eos_token_id
processor = MultimodalModelProcessor(
    tokenizer=text_tokenizer,
    image_processor=image_processor,
)

# Parallelize the model
dist.init_process_group()
ddp_model = DistributedDataParallel(model)
optimizer = Adam(ddp_model.parameters())

# Training step (`inputs` is a batch produced by the processor)
outputs = ddp_model(**inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
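In the example above, `inputs` is a batch produced by the processor. Below is a minimal sketch of how such batches could be fed through a DDP training loop with a DistributedSampler, so that each rank sees a different shard of the data; `dataset`, `collate_fn`, and `num_epochs` are placeholders for illustration and are not part of Cornstarch's API.

```python
# Minimal DDP training-loop sketch. `dataset` and `collate_fn` are assumed to
# exist and to call `processor` on raw (image, text) pairs, returning a dict
# of tensors as the batch.
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

sampler = DistributedSampler(dataset)  # shards the dataset across DDP ranks
dataloader = DataLoader(
    dataset, batch_size=4, sampler=sampler, collate_fn=collate_fn
)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reshuffle with a different seed each epoch
    for inputs in tqdm(dataloader):
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        outputs = ddp_model(**inputs)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```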
Similarly, FSDP can be used by wrapping the model with torch.distributed.fsdp.FullyShardedDataParallel (FSDP1) or torch.distributed._composable.fsdp.fully_shard() (FSDP2):
Note
It is important, for both performance and correctness, to define the wrapping units properly using ModuleWrapPolicy (FSDP1) or per-module fully_shard() calls (FSDP2).
Parameters with different requires_grad values cannot be flattened into the same unit; they must be placed in separate wrap groups, otherwise FSDP raises an error.
Because of this, using FSDP still requires knowledge of the model's internal architecture.
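For example, if the vision encoder is frozen while the language model is fine-tuned (an assumed setup, used here only for illustration), the frozen CLIP layers and the trainable Llama layers must land in different FSDP units. The per-layer wrapping used in the FSDP1 example below satisfies this:

```python
# Assumed fine-tuning setup: freeze the vision encoder, train the rest.
vision_encoder.requires_grad_(False)

# Each CLIPEncoderLayer is now entirely frozen and each LlamaDecoderLayer is
# entirely trainable, so wrapping per layer (as in the FSDP1 example below)
# keeps requires_grad uniform inside every FSDP unit. Wrapping the whole
# model as a single unit would flatten frozen and trainable parameters into
# one FlatParameter and raise an error.
```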
An example of using PyTorch FSDP1:

```python
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from cornstarch.models.multimodal_language_model import MultimodalProjector

# All the same as the DDP example before dist.init_process_group()
dist.init_process_group()

fsdp_model = FullyShardedDataParallel(
    module=model,  # required
    auto_wrap_policy=ModuleWrapPolicy(  # required
        [
            ModalEncoderModule,
            MultimodalProjector,
            torch.nn.Embedding,
            CLIPEncoderLayer,
            LlamaDecoderLayer,
        ]
    ),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # optional
    cpu_offload=CPUOffload(),  # optional
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # optional
    forward_prefetch=True,  # optional
)
optimizer = Adam(fsdp_model.parameters())

outputs = fsdp_model(**inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
An example of using PyTorch FSDP2:

```python
from torch.distributed._composable.fsdp import fully_shard

# Imports, processor, and optimizer setup are the same as in the DDP example.
dist.init_process_group()  # must be called before fully_shard()

vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
language_model = AutoModelForCausalLM.from_pretrained(language_model_name)

# The fully_shard() calls for subgroups do not have to be placed exactly here;
# they can also be issued after the MultimodalModel is created.
for layer in vision_encoder.vision_model.encoder.layers:
    fully_shard(layer)
fully_shard(vision_encoder.vision_model)
for layer in language_model.model.layers:
    fully_shard(layer)

model = MultimodalModel(
    encoders={"vision": vision_encoder},
    language_model=language_model,
).to(dtype=torch.bfloat16, device="cuda")
fully_shard(model.vision_encoder.projector)
fsdp_model = fully_shard(model)

optimizer = Adam(fsdp_model.parameters())

outputs = fsdp_model(**inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
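FSDP2 can also manage mixed precision itself through the mp_policy argument of fully_shard(), instead of casting the whole model to bfloat16 up front. Below is a minimal sketch; the MixedPrecisionPolicy import path is an assumption and depends on your PyTorch version (newer releases also export it from torch.distributed.fsdp).

```python
# Sketch: FSDP2 mixed precision. Parameters can stay in float32 on the model;
# FSDP casts them to bfloat16 for compute/communication and reduces gradients
# in float32. The import path below is version-dependent (an assumption here).
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # dtype used for forward/backward compute
    reduce_dtype=torch.float32,   # dtype used for gradient reduction
)

# Pass the same policy to every fully_shard() call in the example above, e.g.:
for layer in language_model.model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fsdp_model = fully_shard(model, mp_policy=mp_policy)
```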