Using DDP/FSDP
PyTorch DDP and FSDP work by wrapping the original model with their APIs.
This design is also compatible with Cornstarch multimodal LLMs, so DDP and FSDP can be used with Cornstarch as-is.
| An example of using PyTorch DDP | 
|---|
import os

import torch
import torch.distributed as dist
from torch.optim.adam import Adam
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    PreTrainedModel,
    AutoTokenizer,
)
from transformers.models.clip import CLIPVisionModel, CLIPImageProcessor
from cornstarch.models.multimodal_language_model import (
    ModalEncoderModule,
    MultimodalModel,
    MultimodalModelProcessor,
)

# Bind this process to its own GPU before anything is moved to "cuda".
# Without this, "cuda" means cuda:0 on every rank and all replicas would
# share one device. LOCAL_RANK is set per process by torchrun.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Create a multimodal LLM from a vision encoder and a causal language model
vision_model_name = "openai/clip-vit-base-patch32"
language_model_name = "meta-llama/Llama-3.2-1B"
vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
language_model = AutoModelForCausalLM.from_pretrained(language_model_name)
model = MultimodalModel(
    encoders={"vision": vision_encoder},
    language_model=language_model,
).to(dtype=torch.bfloat16, device="cuda")

# Create a processor
image_processor = CLIPImageProcessor.from_pretrained(vision_model_name)
text_tokenizer = AutoTokenizer.from_pretrained(language_model_name, use_fast=True)
# Llama tokenizers typically ship without a pad token; reuse EOS for padding.
text_tokenizer.pad_token_id = text_tokenizer.eos_token_id
processor = MultimodalModelProcessor(
    tokenizer=text_tokenizer,
    image_processor=image_processor,
)

# Parallelize the model: DDP keeps a full replica on each rank and
# all-reduces gradients during backward().
dist.init_process_group()
ddp_model = DistributedDataParallel(model)

optimizer = Adam(ddp_model.parameters())

# NOTE: `inputs` is not defined in this snippet. Build it with `processor`
# from a batch of images and text; it presumably needs to include labels
# for `outputs.loss` to be populated.
outputs = ddp_model(**inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
 | 
Similarly, FSDP can be used by wrapping the model with FullyShardedDataParallel (FSDP1) or by applying torch.distributed._composable.fsdp.fully_shard() (FSDP2):
Note
Properly defining the wrapping units, using ModuleWrapPolicy (FSDP1) or fully_shard (FSDP2), is very important for both performance and correctness.
Parameters with different requires_grad values cannot be wrapped together, so they must be wrapped in separate groups; otherwise an error is raised.
Because of this, using FSDP still requires knowledge of the model's internal architecture.
 
| An example of using PyTorch FSDP1 | 
|---|
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from cornstarch.models.multimodal_language_model import MultimodalProjector

# All the same before dist.init_process_group()
dist.init_process_group()

fsdp_model = FullyShardedDataParallel(
    module=model,  # required
    # Every instance of a listed module class becomes its own FSDP unit.
    # Choose the classes so that parameters with different requires_grad
    # values end up in different units.
    auto_wrap_policy=ModuleWrapPolicy(  # required
        [
            ModalEncoderModule,
            MultimodalProjector,
            torch.nn.Embedding,
            CLIPEncoderLayer,
            LlamaDecoderLayer,
        ]
    ),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # optional
    cpu_offload=CPUOffload(),  # optional
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # optional
    forward_prefetch=True,  # optional
)

optimizer = Adam(fsdp_model.parameters())

outputs = fsdp_model(**inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
 | 
| An example of using PyTorch FSDP2 | 
|---|
import os

from torch.distributed._composable.fsdp import fully_shard

# Pin this process to its own GPU (LOCAL_RANK is set per process by torchrun)
# so that device="cuda" below does not put every rank on cuda:0.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
# fully_shard() builds its default device mesh from the default process
# group, so the group must be initialized before any fully_shard() call.
dist.init_process_group()

vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
language_model = AutoModelForCausalLM.from_pretrained(language_model_name)
model = MultimodalModel(
    encoders={"vision": vision_encoder},
    language_model=language_model,
).to(dtype=torch.bfloat16, device="cuda")

# Convert dtype/device before sharding so that the sharded parameters are
# created from the final bf16 weights.
# Shard bottom-up: each subgroup must be sharded before its parent, and the
# root model last. The subgroup calls could also be made before the models
# are combined into MultimodalModel. Parameters not covered by any subgroup
# (e.g. the LLM embeddings) are sharded by the root call.
for layer in vision_encoder.vision_model.encoder.layers:
    fully_shard(layer)
fully_shard(vision_encoder.vision_model)
for layer in language_model.model.layers:
    fully_shard(layer)
fully_shard(model.vision_encoder.projector)
fsdp_model = fully_shard(model)

optimizer = Adam(fsdp_model.parameters())

# NOTE: `inputs` is not defined in this snippet; build it with the processor
# as in the DDP example.
outputs = fsdp_model(**inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
 |