Creating a Multimodal LLM
Cornstarch supports creating a modular multimodal LLM from HuggingFace models.
For example, you can create a vision-language model (VLM) with Llama 3.1 8B and a CLIP ViT vision encoder as follows:
| from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.clip.modeling_clip import CLIPVisionModel
from cornstarch.models.multimodal_language_model import ModalEncoderModule, MultimodalModel

llm = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")

mllm = MultimodalModel(
    encoders={
        "vision": ModalEncoderModule(vision_encoder),
    },
    language_model=llm,
)
|
where mllm has language_model and vision_encoder modules.
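As an illustrative check (a sketch; it assumes the constructor arguments are stored under the same attribute names, as the architecture tree below suggests), the wrapped submodules can be reached from the mllm object:
| # The language model and the encoder dictionary are attributes of MultimodalModel.
print(type(mllm.language_model))      # LlamaForCausalLM
print(type(mllm.encoders["vision"]))  # ModalEncoderModule
|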
Model Architecture
A simplified MultimodalModel
architecture is as follows:
| cornstarch.MultimodalModel
├── vision_encoder (cornstarch.ModalEncoderModule)
│ ├── module (transformers.PreTrainedModel)
│ └── projector (cornstarch.MultimodalProjector)
├── audio_encoder (cornstarch.ModalEncoderModule)
│ ├── module (transformers.PreTrainedModel)
│ └── projector (cornstarch.MultimodalProjector)
├── whatever_encoder (cornstarch.ModalEncoderModule)
│ ├── module (transformers.PreTrainedModel)
│ └── projector (cornstarch.MultimodalProjector)
└── language_model (transformers.PreTrainedModel)
|
It has one language_model that represents the base LLM from HuggingFace transformers, and it can have an arbitrary number of encoders.
Encoders are stored as a dictionary in MultimodalModel.encoders, where each key is the encoder name (vision, audio, or whatever in the example above) and the corresponding value is a ModalEncoderModule.
ModalEncoderModule is a single-modality encoder that includes an encoder and a projector.
The encoder module comes from HuggingFace transformers, and Cornstarch provides the definition of the projector (cornstarch.models.multimodal_language_model.modeling_multimodal_language_model.MultimodalProjector).
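Following the architecture above, a model with both a vision and an audio encoder can be assembled as in the sketch below (treating a Whisper encoder as an audio ModalEncoderModule and the checkpoint choices are assumptions for illustration, not taken from this page):
| from transformers.models.clip.modeling_clip import CLIPVisionModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.whisper.modeling_whisper import WhisperModel
from cornstarch.models.multimodal_language_model import ModalEncoderModule, MultimodalModel

llm = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
# Take only the encoder part of a Whisper model as an audio encoder.
audio_encoder = WhisperModel.from_pretrained("openai/whisper-large-v3").encoder

mllm = MultimodalModel(
    encoders={
        "vision": ModalEncoderModule(vision_encoder),
        "audio": ModalEncoderModule(audio_encoder),
    },
    language_model=llm,
)
|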
Creating a Projector
Cornstarch provides two ways of creating a MultimodalProjector.
| class MultimodalProjector(PreTrainedModel):
    def __init__(self, config: MultimodalProjectorConfig, projection: Optional[nn.Module] = None): ...
|
First, you can simply wrap your own torch.nn.Module with MultimodalProjector.
When you pass your module as projection while creating a MultimodalProjector instance, Cornstarch uses the given module as the projector module.
The wrapped projector module should then explicitly be given to ModalEncoderModule, as shown below.
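For instance, a hypothetical projector module could be a small two-layer MLP (a sketch; the layer dimensions are placeholders, and your_config in the next snippet stands for a MultimodalProjectorConfig instance you create yourself):
| import torch.nn as nn

# A hypothetical projector mapping the encoder hidden size (here 1024)
# to the LLM hidden size (here 4096); the dimensions are placeholders.
my_projector_module = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.GELU(),
    nn.Linear(4096, 4096),
)
|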
| wrapped_projector_module = MultimodalProjector(your_config, my_projector_module)
encoder = ModalEncoderModule(module=my_encoder_module, projector=wrapped_projector_module)
|
Second, Cornstarch can automatically initialize a new projector if no projector is given to ModalEncoderModule:
| encoder = ModalEncoderModule(module=my_encoder_module)
# which is equivalent to
encoder = ModalEncoderModule(module=my_encoder_module, projector=None)
|
It adopts lazy initialization: a projector module is not initialized when a ModalEncoderModule is created.
Instead, when a MultimodalModel is created, it checks whether the projection module in the given MultimodalProjector is None and creates a projector module if so.
MultimodalModel accepts two arguments that control projector creation: init_projector_type and init_activation:
| class MultimodalModel(nn.Module):
    def __init__(
        self,
        ...,
        init_projector_type: str = "linear",
        init_activation: str = "gelu",
    ): ...
|
Currently MultimodalModel accepts either linear or mlp as init_projector_type:
linear: a single torch.nn.Linear layer.
mlp: two torch.nn.Linear layers with an activation layer of the init_activation type in between.
The activation types that Cornstarch supports are defined in transformers.activations.ACT2CLS.
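For example, to have Cornstarch build a two-layer MLP projector with a SiLU activation for every lazily initialized encoder, pass the arguments when constructing the MultimodalModel (a sketch reusing vision_encoder and llm from the first example; "silu" is one of the ACT2CLS keys):
| mllm = MultimodalModel(
    encoders={
        # no projector given, so it is lazily initialized
        "vision": ModalEncoderModule(vision_encoder),
    },
    language_model=llm,
    init_projector_type="mlp",
    init_activation="silu",
)
|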
Using PEFT
Cornstarch is compatible with HuggingFace PEFT.
Before being passed to MultimodalModel, a model can be wrapped with PEFT (e.g. via get_peft_model):
| from transformers.models.llama.modeling_llama import LlamaForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
from cornstarch.models.multimodal_language_model import MultimodalModel
llm = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False)
llm = get_peft_model(llm, peft_config)
mllm = MultimodalModel(..., language_model=llm)
|
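As a quick sanity check (standard PEFT API, not Cornstarch-specific), you can confirm that only the LoRA parameters of the language model remain trainable before building the multimodal model:
| # Report trainable (LoRA) parameters versus total parameters of the wrapped LLM.
llm.print_trainable_parameters()
|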
Callback Interface
Cornstarch provides a callback interface to add multimodal-specific features without modifying the underlying unimodal modules.
For example, Llava-Next has to select image features based on its feature select strategy, which its base vision encoder, such as CLIPVisionModel, does not implement.
A callback is a good place to implement such features.
Cornstarch provides three types of callbacks in ModalEncoderModule:
| ModalEncoderModule(
    model=vision_encoder,
    projector=vision_projector,
    preprocess_callback=preprocess_vision_callback,
    postprocess_module_callback=postprocess_vision_callback,
    postprocess_projector_callback=postprocess_projector_callback,
)
|
The execution order of callbacks and modules is as follows:
| # encoder: ModalEncoderModule
for encoder in mllm.encoders.values():
    preprocessed_encoder_input = encoder.preprocess_callback(encoder_input)
    encoder_output = encoder.module(preprocessed_encoder_input)
    postprocessed_encoder_output = encoder.postprocess_module_callback(encoder_output)
    module_output = encoder.projector(postprocessed_encoder_output)
    postprocessed_module_output = encoder.postprocess_projector_callback(module_output)

# merge the list of postprocessed_module_outputs into the text embedding
merged_input = merge(postprocessed_module_outputs, language_inputs_embedding)
output = language_model(merged_input)
|
Inputs and outputs of each callback are as follows:
preprocess_callback(inputs: dict[str, Any]) -> dict[str, Any]: gets the inputs of the modality encoder as a dictionary and returns a modified dictionary that is used as the actual inputs of the modality encoder.
postprocess_module_callback(inputs: dict[str, Any], output: BaseModelOutput | tuple) -> BaseModelOutput | tuple: gets the inputs and the output of the modality encoder. The output is either BaseModelOutput or tuple, depending on the encoder configuration return_dict. Returns a modified output that is forwarded to the projector.
postprocess_projector_callback(inputs: dict[str, Any], output: BaseModelOutput | tuple) -> BaseModelOutput | tuple: gets the inputs of the modality encoder and the output of the projector. The output is either BaseModelOutput or tuple, depending on the encoder configuration return_dict. Returns a modified output that is forwarded to the LLM.
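As a minimal illustration of the interface, a preprocess_callback could adjust the encoder inputs before they reach the vision encoder (a sketch reusing vision_encoder from the first example; the dtype cast is only an example transformation):
| from typing import Any

import torch

from cornstarch.models.multimodal_language_model import ModalEncoderModule

def preprocess_vision_callback(inputs: dict[str, Any]) -> dict[str, Any]:
    # Example transformation: cast pixel values to bfloat16 before the encoder runs.
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    return inputs

encoder = ModalEncoderModule(
    model=vision_encoder,
    preprocess_callback=preprocess_vision_callback,
)
|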
A Llava-Next example of utilizing the callback interface
The original LlavaNextForConditionalGeneration.forward() is implemented as follows:
| class LlavaNextForConditionalGeneration:
    def forward(
        self,
        ...
    ):
        ...
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        image_features = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
            image_features, feature_lens = self.pack_image_features(
                image_features,
                image_sizes,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )

            # embed vision output result to inputs_embeds
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(...)
        ...
|
The Llava-Next-specific feature, packing the selected image features before they are merged into the text embedding, can be implemented in a callback:
| from typing import Optional
from transformers.modeling_outputs import BaseModelOutput, ModelOutput

def postprocess_projector_callback(
    inputs: dict,
    output: BaseModelOutput | tuple,
    vision_feature_select_strategy: Optional[str] = None,
) -> BaseModelOutput | tuple:
    pixel_values = inputs.get("pixel_values", None)
    if pixel_values is not None and pixel_values.size(0) > 0:
        # output[0] == output.last_hidden_state
        image_features = output[0]

        # pack_image_features should be borrowed from the
        # LlavaNextForConditionalGeneration class; image_sizes and
        # image_newline must also be provided to this callback.
        image_features, feature_lens = pack_image_features(
            image_features,
            image_sizes,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_newline=image_newline,
        )

        # replace the output hidden state with the postprocessed results
        if isinstance(output, ModelOutput):
            output.last_hidden_state = image_features
        else:
            output = (image_features,) + output[1:]

    return output
|
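If the callback needs values that do not arrive through the encoder inputs, such as a non-default vision_feature_select_strategy, one option (a generic Python sketch, not a Cornstarch-specific mechanism) is to bind them with functools.partial before registering the callback:
| import functools

# Bind extra arguments while keeping the (inputs, output) signature
# expected by ModalEncoderModule; the "default" strategy value is illustrative.
bound_callback = functools.partial(
    postprocess_projector_callback,
    vision_feature_select_strategy="default",
)
|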
The callback can then be used with any combination of a vision encoder and an LLM:
Llava-Next with CLIP+Mistral using Cornstarch:
| from transformers.models.clip.modeling_clip import CLIPVisionModel
from transformers.models.mistral.modeling_mistral import MistralForCausalLM
from cornstarch.models.multimodal_language_model import ModalEncoderModule, MultimodalModel

clip = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
mistral = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

mllm = MultimodalModel(
    encoders={
        "vision": ModalEncoderModule(
            model=clip,
            postprocess_projector_callback=postprocess_projector_callback,
        )
    },
    language_model=mistral,
)
|
Llava-Next with Siglip+Llama using Cornstarch:
| from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from cornstarch.models.multimodal_language_model import ModalEncoderModule, MultimodalModel

siglip = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
llama = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

mllm = MultimodalModel(
    encoders={
        "vision": ModalEncoderModule(
            model=siglip,
            postprocess_projector_callback=postprocess_projector_callback,
        )
    },
    language_model=llama,
)
|