Creating a Multimodal LLM
Cornstarch supports creating a modular multimodal LLM from HuggingFace models.
For example, you can create a vision-language model (VLM) with Llama 3.1 8B and a CLIP ViT vision encoder as follows:
| from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.clip.modeling_clip import CLIPVisionModel
from cornstarch.models.multimodal_language_model import ModalEncoderModule, MultimodalModel
llm = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
mllm = MultimodalModel(
encoders={
"vision": ModalEncoderModule(vision_encoder),
},
language_model=llm,
)
|
where mllm has language_model and vision_encoder modules.
Model Architecture
A simplified MultimodalModel architecture is as follows:
| cornstarch.MultimodalModel
├── vision_encoder (cornstarch.ModalEncoderModule)
│ ├── module (transformers.PreTrainedModel)
│ └── projector (cornstarch.MultimodalProjector)
├── audio_encoder (cornstarch.ModalEncoderModule)
│ ├── module (transformers.PreTrainedModel)
│ └── projector (cornstarch.MultimodalProjector)
├── whatever_encoder (cornstarch.ModalEncoderModule)
│ ├── module (transformers.PreTrainedModel)
│ └── projector (cornstarch.MultimodalProjector)
└── language_model (transformers.PreTrainedModel)
|
It has one language_model that represents the base LLM from HuggingFace transformers, and can have an arbitrary number of encoders.
Encoders are stored as a dictionary in MultimodalModel.encoders, where each key is the encoder name (vision, audio, or whatever in the example above) and the corresponding value is a ModalEncoderModule.
ModalEncoderModule is a single modality encoder that includes an encoder and a projector.
The encoder module comes from HuggingFace transformers, and Cornstarch provides the definition of the projector (cornstarch.models.multimodal_language_model.modeling_multimodal_language_model.MultimodalProjector).
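For instance, the submodules of the mllm created in the first example can be accessed as sketched below; the attribute names follow the architecture diagram above.
| # A minimal sketch, reusing the mllm object built in the first example.
# Attribute names follow the architecture diagram above.
vision_encoder_module = mllm.encoders["vision"]  # cornstarch.ModalEncoderModule
hf_encoder = vision_encoder_module.module        # the wrapped HuggingFace encoder
projector = vision_encoder_module.projector      # cornstarch.MultimodalProjector
llm = mllm.language_model                        # the base HuggingFace LLM
|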
Creating a Projector
Cornstarch provides two ways of creating a MultimodalProjector.
| class MultimodalProjector(PreTrainedModel):
def __init__(self, config: MultimodalProjectorConfig, projection: Optional[nn.Module] = None): ...
|
First, you can simply wrap your own torch.nn.Module with MultimodalProjector.
When you pass your module as the projection argument while creating a MultimodalProjector instance, Cornstarch uses the given module as the projector.
The wrapped projector module should then be explicitly given to ModalEncoderModule.
| wrapped_projector_module = MultimodalProjector(your_config, my_projector_module)
encoder = ModalEncoderModule(module=my_encoder_module, projector=wrapped_projector_module)
|
Second, Cornstarch can automatically initialize a new projector if no projector is given in ModalEncoderModule:
| encoder = ModalEncoderModule(module=my_encoder_module)
# which is equivalent to
encoder = ModalEncoderModule(module=my_encoder_module, projector=None)
|
Cornstarch adopts lazy initialization: the projector module is not initialized when a ModalEncoderModule is created.
Instead, when a MultimodalModel is created, it checks whether the projector of each given ModalEncoderModule is None, and creates a projector module if so.
MultimodalModel accepts two arguments to control how such a projector is created: init_projector_type and init_activation:
| class MultimodalModel(nn.Module):
def __init__(
self,
...,
init_projector_type: str = "linear",
init_activation: str = "gelu",
): ...
|
Currently MultimodalModel accepts either linear or mlp as an init_projector_type:
linear: has a single torch.nn.Linear layer.
mlp: has two torch.nn.Linear layers with an activation layer of the init_activation type in between.
The activation types that Cornstarch supports are those defined in transformers.activations.ACT2CLS.
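As a sketch, the snippet below reuses the vision encoder and LLM from the first example and lets Cornstarch lazily create an MLP projector with a SiLU activation (silu is one of the ACT2CLS activation names).
| # A minimal sketch, reusing vision_encoder and llm from the first example.
# No projector is given, so MultimodalModel lazily creates an MLP projector
# with a SiLU activation between its two linear layers.
mllm = MultimodalModel(
    encoders={"vision": ModalEncoderModule(vision_encoder)},
    language_model=llm,
    init_projector_type="mlp",
    init_activation="silu",
)
|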
Using PEFT
Cornstarch is compatible with HuggingFace PEFT.
Before passing a model to MultimodalModel, the model can be wrapped with PEFT (e.g. via get_peft_model):
| from transformers.models.llama.modeling_llama import LlamaForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
from cornstarch.models.multimodal_language_model import MultimodalModel
llm = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False)
llm = get_peft_model(llm, peft_config)
mllm = MultimodalModel(..., language_model=llm)
|
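PEFT wrapping is not limited to the language model. As a sketch (the LoRA target module names below are illustrative and depend on the encoder you use), an encoder can be wrapped the same way before it is given to ModalEncoderModule:
| # A minimal sketch: applying LoRA to a vision encoder before wrapping it.
# The target_modules names are illustrative; adjust them to your encoder.
from transformers.models.clip.modeling_clip import CLIPVisionModel
from peft import LoraConfig, get_peft_model
from cornstarch.models.multimodal_language_model import ModalEncoderModule, MultimodalModel

clip = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
encoder_peft_config = LoraConfig(target_modules=["q_proj", "v_proj"], inference_mode=False)
clip = get_peft_model(clip, encoder_peft_config)

mllm = MultimodalModel(
    encoders={"vision": ModalEncoderModule(clip)},
    language_model=llm,  # the PEFT-wrapped LLM from the snippet above
)
|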
Callback Interface
Cornstarch provides a callback interface to add multimodal-specific features without modifying the underlying unimodal modules.
For example, Llava-Next has to select image features based on its feature-select strategy, a behavior that its base vision encoder (such as CLIPVisionModel) does not implement.
A callback is a great place to implement such features.
Cornstarch provides three types of callbacks in ModalEncoderModule:
| ModalEncoderModule(
model=vision_encoder,
projector=vision_projector,
preprocess_callback=preprocess_vision_callback,
postprocess_module_callback=postprocess_vision_callback,
postprocess_projector_callback=postprocess_projector_callback,
)
|
The execution order of callbacks and modules is as follows:
| encoder: ModalEncoderModule
postprocessed_module_outputs = []
for encoder in mllm.encoders.values():
    preprocessed_encoder_input = encoder.preprocess_callback(encoder_input)
    encoder_output = encoder.module(preprocessed_encoder_input)
    postprocessed_encoder_output = encoder.postprocess_module_callback(
        preprocessed_encoder_input, encoder_output
    )
    projector_output = encoder.projector(postprocessed_encoder_output)
    postprocessed_module_output = encoder.postprocess_projector_callback(
        preprocessed_encoder_input, projector_output
    )
    postprocessed_module_outputs.append(postprocessed_module_output)
# merge the list of postprocessed_module_outputs into the text embedding
merged_input = merge(postprocessed_module_outputs, language_inputs_embedding)
output = language_model(merged_input)
|
Inputs and outputs of each callback are as follows:
preprocess_callback(inputs: dict[str, Any]) -> dict[str, Any]: gets the inputs of the modality encoder as a dictionary. Returns a modified dictionary, which is used as the actual inputs of the modality encoder.
postprocess_module_callback(inputs: dict[str, Any], output: BaseModelOutput | tuple) -> BaseModelOutput | tuple: gets the inputs and the output of the modality encoder. The output is either a BaseModelOutput or a tuple, depending on the encoder's return_dict configuration. Returns a modified output, which is forwarded to the projector.
postprocess_projector_callback(inputs: dict[str, Any], output: BaseModelOutput | tuple) -> BaseModelOutput | tuple: gets the inputs of the modality encoder and the output of the projector. The output is either a BaseModelOutput or a tuple, depending on the encoder's return_dict configuration. Returns a modified output, which is forwarded to the LLM.
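As an illustrative sketch (the callback bodies below are hypothetical examples, not part of Cornstarch), a preprocess_callback could cast pixel values to the encoder's dtype and a postprocess_module_callback could drop the first token of the encoder output before it reaches the projector:
| # A minimal sketch of user-defined callbacks; the bodies are illustrative only.
from typing import Any

import torch
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from cornstarch.models.multimodal_language_model import ModalEncoderModule

def preprocess_vision_callback(inputs: dict[str, Any]) -> dict[str, Any]:
    # e.g. cast pixel values to half precision before they reach the encoder
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
    return inputs

def postprocess_vision_callback(
    inputs: dict[str, Any], output: BaseModelOutput | tuple
) -> BaseModelOutput | tuple:
    # e.g. drop the first (CLS) token before the features go to the projector
    hidden_states = output[0][:, 1:]
    if isinstance(output, ModelOutput):
        output.last_hidden_state = hidden_states
        return output
    return (hidden_states,) + output[1:]

encoder = ModalEncoderModule(
    model=vision_encoder,  # the HuggingFace encoder loaded earlier
    preprocess_callback=preprocess_vision_callback,
    postprocess_module_callback=postprocess_vision_callback,
)
|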
A Llava-Next example of utilizing the callback interface
The original LlavaNextForConditionalGeneration.forward() is implemented as follows:
| class LlavaNextForConditionalGeneration:
def forward(
self,
...
):
...
        if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
image_features = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
# embed vision output result to inputs_embeds
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(...)
...
|
This Llava-Next-specific feature (selecting and packing image features) can be implemented in a callback:
| from typing import Optional
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
def postprocess_projector_callback(
inputs: dict,
output: BaseModelOutput | tuple,
vision_feature_select_strategy: Optional[str] = None,
) -> BaseModelOutput | tuple:
    pixel_values = inputs.get("pixel_values", None)
    # image_sizes is assumed to be passed along with the encoder inputs
    image_sizes = inputs.get("image_sizes", None)
    if pixel_values is not None and pixel_values.size(0) > 0:
        # output[0] == output.last_hidden_state
        image_features = output[0]
        # pack_image_features and image_newline should be borrowed from
        # the LlavaNextForConditionalGeneration class
        image_features, feature_lens = pack_image_features(
            image_features,
            image_sizes,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_newline=image_newline,
        )
# replace output hidden state with postprocessed results
if isinstance(output, ModelOutput):
output.last_hidden_state = image_features
else:
output = (image_features,) + output[1:]
return output
|
which can be used for any combination of a vision encoder and an LLM:
Llava-Next with CLIP+Mistral using Cornstarch:
| clip = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
mistral = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
mllm = MultimodalModel(
encoders={
"vision": ModalEncoderModule(
model=clip,
postprocess_projector_callback=postprocess_projector_callback,
)
},
language_model=mistral,
)
|
Llava-Next with Siglip+Llama using Cornstarch:
| siglip = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
llama = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
mllm = MultimodalModel(
encoders={
"vision": ModalEncoderModule(
model=siglip,
postprocess_projector_callback=postprocess_projector_callback,
)
},
language_model=llama,
)
|