transformer package
The transformer package provides a customizable and configurable implementation of the transformer model architecture. Each component of a transformer stack, from entire layers down to individual linear layers, can be customized by swapping in different PyTorch modules via the "spec" parameters. The configuration of the transformer (hidden size, number of layers, number of attention heads, etc.) is provided via a TransformerConfig object.
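The spec idea can be sketched as follows. This is an illustrative, stdlib-only sketch of the pattern, not the actual Megatron-Core API: the LayerSpec, build_layer, and module class names here are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class LayerSpec:
    # Each field names the class to instantiate for one slot of the layer.
    attention: Callable
    mlp: Callable

class DotProductAttention:
    def __call__(self, x):
        return f"dot_product_attention({x})"

class FlashAttention:
    def __call__(self, x):
        return f"flash_attention({x})"

class SimpleMLP:
    def __call__(self, x):
        return f"mlp({x})"

def build_layer(spec: LayerSpec):
    # The layer code instantiates whatever classes the spec names,
    # so components are swapped without touching the layer itself.
    return spec.attention(), spec.mlp()

# Swapping the attention implementation is a one-line spec change.
attn, mlp = build_layer(LayerSpec(attention=FlashAttention, mlp=SimpleMLP))
```

The layer code never hard-codes a concrete attention class; it only follows the spec, which is what makes per-component substitution possible.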
Submodules
transformer.attention module
This is the entire attention portion of a transformer layer, either self- or cross-attention, including the query, key, and value projections, a "core" attention calculation (e.g. dot-product attention), and the final output linear projection.
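The structure described above can be sketched like this. This is a structural sketch only, not the real Megatron-Core class: the projections and the core attention are plain callables here rather than torch.nn modules.

```python
class Attention:
    """Structural sketch: Q/K/V projections wrapped around a pluggable
    "core" attention callable, followed by an output projection."""

    def __init__(self, project_q, project_k, project_v, core_attention, project_out):
        self.project_q = project_q
        self.project_k = project_k
        self.project_v = project_v
        self.core_attention = core_attention  # e.g. dot-product attention
        self.project_out = project_out

    def __call__(self, hidden, context=None):
        # Self-attention draws Q, K, and V from `hidden`; cross-attention
        # draws K and V from a separate `context` sequence.
        kv_source = hidden if context is None else context
        q = self.project_q(hidden)
        k = self.project_k(kv_source)
        v = self.project_v(kv_source)
        return self.project_out(self.core_attention(q, k, v))
```

Because the core attention is just another pluggable component, the same wrapper serves a plain PyTorch implementation or a fused kernel.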
core.transformer.attention
transformer.dot_product_attention module
This is a PyTorch-only implementation of dot-product attention. More efficient implementations, such as those provided by FlashAttention or cuDNN's fused attention, are typically used when training speed is important.
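The underlying computation is scaled dot-product attention, softmax(QKᵀ/√d)V. The following is a minimal stdlib sketch of that formula on nested lists, for illustration only; the actual module operates on batched PyTorch tensors.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot_product_attention(Q, K, V):
    """Scaled dot-product attention on plain nested lists.
    Q: [n_q][d], K: [n_k][d], V: [n_k][d_v] -> output [n_q][d_v]."""
    d = len(K[0])
    out = []
    for q in Q:
        # Similarity of this query with every key, scaled by sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        weights = softmax(scores)
        # Output is the attention-weighted average of the value rows.
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out
```

When all keys are identical the weights are uniform, so each output row is simply the mean of the value rows.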
core.transformer.dot_product_attention
transformer.enums module
core.transformer.enums
transformer.identity_op module
This provides a pass-through module that can be used in specs to indicate that the operation should not be performed. For example, when LayerNorm is fused with the subsequent linear layer, an IdentityOp can be passed in as the standalone LayerNorm module.
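The idea can be sketched in a few lines. This mirrors the pass-through behavior described above but is a stdlib-only sketch, not the torch.nn.Module-based class in the package.

```python
class IdentityOp:
    """Pass-through stand-in for a spec slot: accepts and ignores any
    construction arguments (e.g. a hidden size a real LayerNorm would
    need) and returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        pass  # ignore construction arguments so it can fill any slot

    def __call__(self, x, *args, **kwargs):
        return x  # no-op: the operation is skipped
```

A spec that would normally name a LayerNorm class can name IdentityOp instead, and the surrounding layer code needs no changes.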
core.transformer.identity_op
transformer.mlp module
This is the entire MLP portion of the transformer layer with an input projection, non-linearity, and output projection.
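That three-step structure can be sketched as follows, using a stdlib-only linear map and the tanh approximation of GELU as an assumed non-linearity; the real module uses PyTorch linear layers on tensors.

```python
import math

def gelu(x):
    # tanh approximation of GELU, a common transformer non-linearity.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def linear(x, W, b):
    # x: [d_in], W: [d_out][d_in], b: [d_out] -> [d_out]
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def mlp(x, W_in, b_in, W_out, b_out):
    # Input projection (typically widening the hidden size), then the
    # non-linearity, then the output projection back down.
    h = [gelu(a) for a in linear(x, W_in, b_in)]
    return linear(h, W_out, b_out)
```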
core.transformer.mlp
transformer.module module
This provides a common base class for all modules used in the transformer that contains some common functionality.
core.transformer.module
transformer.transformer_block module
A block, or stack, of several transformer layers. The layers can all be the same or each can be unique.
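The stacking behavior can be sketched minimally: a block applies its layers in order, and because the layers are supplied as a list, each slot can hold a different implementation. This is an illustrative sketch, not the real TransformerBlock class.

```python
class TransformerBlock:
    """Sketch of a block as an ordered stack of layer callables."""

    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)  # output of one layer feeds the next
        return x

# The layers may all be the same, or each may be unique:
double = lambda x: 2 * x
inc = lambda x: x + 1
block = TransformerBlock([double, inc, double])
```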
core.transformer.transformer_block
transformer.transformer_config module
This contains all of the configuration options for the transformer. Keeping all arguments together in a single dataclass reduces code bloat compared with passing many individual arguments through multiple layers of function calls.
core.transformer.transformer_config
transformer.transformer_layer module
A single standard transformer layer including attention and MLP blocks.
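The composition of a layer can be sketched as two sub-blocks, each wrapped in a residual connection. This is a simplified sketch (normalization omitted, sub-modules passed in as callables in the spirit of the spec-driven design), not the actual TransformerLayer class.

```python
class TransformerLayer:
    """Sketch of a standard layer: attention then MLP, each residual."""

    def __init__(self, attention, mlp):
        self.attention = attention
        self.mlp = mlp

    def __call__(self, x):
        x = x + self.attention(x)  # residual connection around attention
        x = x + self.mlp(x)        # residual connection around the MLP
        return x
```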
core.transformer.transformer_layer
transformer.utils module
Various utilities used in the transformer implementation.
core.transformer.utils
Module contents
core.transformer