Skip to content

Configuration

config

DataConfig dataclass

Configuration for data loading, preprocessing, and tokenization.

Source code in src/quickmt_train/config.py
@dataclass
class DataConfig:
    """Configuration for data loading, preprocessing, and tokenization.

    Groups corpus paths, SentencePiece tokenizer training settings,
    streaming/batching parameters, and subword-regularization (n-best
    sampling) options. Field order is preserved for positional callers.
    """

    # Experiment info (usually populated from TrainConfig); used as the
    # directory prefix for the tokenizer artifact paths below.
    experiment_name: str = "default"

    # Languages (language codes used in data file suffixes)
    src_lang: str = "fa"
    tgt_lang: str = "en"

    # Paths. `corpora` defaults to None (a mutable default list would be
    # shared across instances); __post_init__ normalizes it to a fresh
    # empty list, so downstream code can always iterate it. The annotation
    # is now honest about accepting None at construction time.
    corpora: "list[CorpusConfig] | None" = None
    src_dev_path: str = "data/dev.fa"
    tgt_dev_path: str = "data/dev.en"

    # Tokenizer (SentencePiece training parameters)
    char_coverage: float = 0.9999
    input_sentence_size: int = 10_000_000

    # Streaming & Batching
    max_tokens_per_batch: int = 6000
    buffer_size: int = 10000
    num_workers: int = 4
    prefetch_factor: int = 64
    pad_multiple: int = 8  # presumably pad lengths to a multiple of this — confirm in collator

    # N-best sampling (SentencePiece subword regularization); nbest_size=1
    # with alpha=0.0 corresponds to deterministic segmentation.
    src_spm_nbest_size: int = 1
    tgt_spm_nbest_size: int = 1
    src_spm_alpha: float = 0.0
    tgt_spm_alpha: float = 0.0

    def __post_init__(self) -> None:
        """Normalize `corpora` so it is always a (possibly empty) list."""
        if self.corpora is None:
            self.corpora = []

    @property
    def tokenizer_prefix_src(self) -> str:
        """Path prefix for source-language tokenizer files."""
        return os.path.join(self.experiment_name, "tokenizer_src")

    @property
    def tokenizer_prefix_tgt(self) -> str:
        """Path prefix for target-language tokenizer files."""
        return os.path.join(self.experiment_name, "tokenizer_tgt")

ExportConfig dataclass

Configuration for checkpoint averaging, quantization, and export.

Source code in src/quickmt_train/config.py
@dataclass
class ExportConfig:
    """Configuration for checkpoint averaging, quantization, and export."""

    # Averaging: k checkpoints are averaged (presumably the last k — confirm
    # in the averaging code).
    k: int = 5

    # Quantization
    export_int8: bool = False
    calib_batches: int = 200  # number of batches used for quantization calibration
    quantization: str = "int8"
    qconfig_backend: str = "fbgemm"  # "fbgemm" or "qnnpack"

    # Inference Defaults
    beam_size: int = 5
    max_len: int = 256  # max decoded sequence length at inference
    batch_size: int = 32

    # CT2 specific (CTranslate2 source-side token options)
    add_source_bos: bool = True
    add_source_eos: bool = False

    # Experiment info (usually populated from TrainConfig); used as the
    # directory prefix for all derived paths below.
    experiment_name: str = "default"

    @property
    def output_dir(self) -> str:
        """Directory under the experiment dir for the exported model."""
        return os.path.join(self.experiment_name, "exported_model")

    @property
    def src_vocab(self) -> str:
        """Path to the source-language tokenizer vocabulary file."""
        return os.path.join(self.experiment_name, "tokenizer_src.vocab")

    @property
    def tgt_vocab(self) -> str:
        """Path to the target-language tokenizer vocabulary file."""
        return os.path.join(self.experiment_name, "tokenizer_tgt.vocab")

    @property
    def output_prefix(self) -> str:
        """Path prefix for the averaged-checkpoint model."""
        return os.path.join(self.experiment_name, "averaged_model")

ModelConfig dataclass

Configuration for the Transformer model architecture.

Source code in src/quickmt_train/config.py
@dataclass
class ModelConfig:
    """Configuration for the Transformer model architecture."""

    d_model: int = 768  # model/embedding hidden size
    enc_layers: int = 12  # deep-encoder / shallow-decoder layout
    dec_layers: int = 2
    n_heads: int = 16
    ffn_dim: int = 4096  # feed-forward inner dimension
    max_len: int = 512  # Hard filter during data loading
    dropout: float = 0.1
    vocab_size_src: int = 32000
    vocab_size_tgt: int = 32000
    use_checkpoint: bool = False  # presumably activation checkpointing — confirm in model code
    ff_bias: bool = True  # whether feed-forward layers carry bias terms
    layernorm_eps: float = 1e-6
    activation: str = "gelu"
    mlp_type: str = "standard"  # "standard" or "gated"
    norm_type: str = "layernorm"  # "layernorm" or "rmsnorm"
    tie_decoder_embeddings: bool = False  # presumably ties decoder in/out embeddings — confirm

    # Special Tokens (ID conventions shared with the tokenizer)
    pad_id: int = 0
    unk_id: int = 1
    bos_id: int = 2
    eos_id: int = 3

TrainConfig dataclass

Configuration for the training loop and optimization.

Source code in src/quickmt_train/config.py
@dataclass
class TrainConfig:
    """Configuration for the training loop and optimization."""

    experiment_name: str = "default"  # also the path prefix for checkpoint_dir
    aim_repo: str = "./aim-runs"  # Aim experiment-tracking repository path

    # Optimizer (Adam-family hyperparameters)
    lr: float = 1.0e-3
    weight_decay: float = 0.01
    adam_eps: float = 1e-6
    label_smoothing: float = 0.1
    adam_beta1: float = 0.9
    adam_beta2: float = 0.998

    # Scheduler
    scheduler_type: str = "inv_sqrt"  # "inv_sqrt" or "cosine"
    warmup_steps: int = 5000
    max_steps: int = 100000
    epochs: int = 20

    # Training Loop
    accum_steps: int = 30  # gradient-accumulation steps per optimizer update
    grad_clip: float = 1.0  # presumably max gradient norm — confirm in training loop
    eval_steps: int = 1000  # presumably validate every N steps — confirm
    max_checkpoints: int = 5  # retain at most this many checkpoints
    save_checkpoints: bool = True

    # Hardware & Performance
    device: str = "cuda"  # "cuda", "cpu", or "auto"
    precision: str = "bf16"  # "bf16", "fp16", "fp32"
    tf32: bool = True  # allow TF32 matmuls on supported GPUs

    # Logging & Validation
    log_steps: int = 100
    val_max_samples: int = 500
    quick_test_samples: int = 5

    # Checkpoint Resume
    resume_from: str = ""  # Path to .pt or .safetensors checkpoint
    reset_optimizer: bool = False  # Reset optimizer/scheduler state (for fine-tuning)

    # torch_compile
    enable_torch_compile: bool = True

    @property
    def checkpoint_dir(self) -> str:
        """Directory where training checkpoints are written."""
        return os.path.join(self.experiment_name, "checkpoints")