Skip to content

Training

train

setup_dist(train_cfg)

Initialize distributed training environment. Returns (rank, local_rank, world_size)

Source code in src/quickmt_train/train.py
def setup_dist(train_cfg):
    """
    Set up the distributed training environment.

    Detects a torchrun launch via the RANK/WORLD_SIZE environment
    variables; otherwise falls back to single-process defaults.

    Returns:
        tuple: (rank, local_rank, world_size)
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        # Not launched by torchrun: run as a single process.
        return 0, 0, 1

    # torchrun launch: read process identity from the environment.
    rank = int(env["RANK"])
    local_rank = int(env["LOCAL_RANK"])
    world_size = int(env["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size

train_cli(config, **kwargs)

Train a Transformer model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `str` | Path to config file | *required* |
| `**kwargs` | | Overrides for configuration parameters (e.g., `--max_steps 100`) | `{}` |
Source code in src/quickmt_train/train.py
def train_cli(config: str, **kwargs):
    """
    Train a Transformer model.

    Args:
        config: Path to config file
        **kwargs: Overrides for configuration parameters (e.g., --max_steps 100)
    """
    from .config import load_config

    model_cfg, data_cfg, train_cfg, _ = load_config(config)

    # Route each CLI override to the first config object that owns the key
    # (train config wins over model config, which wins over data config).
    override_targets = (train_cfg, model_cfg, data_cfg)
    for key, value in kwargs.items():
        for cfg in override_targets:
            if hasattr(cfg, key):
                setattr(cfg, key, value)
                break
        else:
            print(f"Warning: Configuration key '{key}' not found in any config object.")

    # Create the experiment folder if it does not already exist.
    os.makedirs(train_cfg.experiment_name, exist_ok=True)

    # Snapshot the config alongside the experiment outputs.
    copyfile(config, os.path.join(train_cfg.experiment_name, "config.yaml"))  # type: ignore

    train(model_cfg, data_cfg, train_cfg)

validate(model, loader, src_sp, tgt_sp, device, train_cfg, data_cfg, model_cfg, get_time_info, use_autoregressive=False)

Validate the model.

Source code in src/quickmt_train/train.py
def validate(
    model,
    loader,
    src_sp,
    tgt_sp,
    device,
    train_cfg,
    data_cfg,
    model_cfg,
    get_time_info,
    use_autoregressive=False,
):
    """
    Validate the model.

    Runs one pass over ``loader``, accumulating token-level loss and
    accuracy across all batches, and decodes up to
    ``train_cfg.val_max_samples`` examples to compute corpus BLEU and
    ChrF with sacrebleu. Prints a summary line plus a few sample
    reference/hypothesis pairs, restores train mode, and returns the
    metrics.

    Args:
        model: Transformer model; may be DataParallel-wrapped (per-GPU
            outputs are summed below).
        loader: Iterable yielding (src, tgt) tensor batches.
        src_sp: Source SentencePiece processor (not referenced in this
            function — presumably kept for interface symmetry; TODO confirm).
        tgt_sp: Target SentencePiece processor used to decode token ids.
        device: torch device the batches are moved to.
        train_cfg: Uses precision, val_max_samples, quick_test_samples.
        data_cfg: Not referenced in this function.
        model_cfg: Uses pad_id, bos_id, eos_id, max_len.
        get_time_info: Zero-arg callable returning a timestamp string
            used as the log-line prefix.
        use_autoregressive: If True, generate hypotheses autoregressively;
            otherwise reuse the teacher-forced argmax predictions.

    Returns:
        dict: {"loss", "ppl", "acc", "bleu", "chrf"} validation metrics.
    """
    model.eval()
    total_loss_sum = 0
    total_tokens = 0
    correct_tokens = 0

    # Limit samples for BLEU calculation to reduce memory
    max_samples = train_cfg.val_max_samples
    hypotheses = []
    references = []
    sample_count = 0

    # Use inference_mode instead of no_grad for better performance
    # Pick the autocast dtype from the configured precision; CPU stays fp32.
    autocast_dtype = torch.float32
    if device.type == "cuda":
        if train_cfg.precision in ("bf16", "bfloat16"):
            autocast_dtype = torch.bfloat16
        elif train_cfg.precision in ("fp16", "float16"):
            autocast_dtype = torch.float16

    with torch.inference_mode():
        for batch_idx, (src, tgt) in enumerate(loader):
            src, tgt = (
                src.to(device, non_blocking=True),
                tgt.to(device, non_blocking=True),
            )

            # Forward pass for loss and logits (calculates loss internally)
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                loss_sum, (logits, num_tokens_batch) = model(
                    src, tgt, return_outputs=True
                )

            # Handle DataParallel output (vectors per GPU)
            if loss_sum.ndim > 0:
                loss_sum = loss_sum.sum()
            if num_tokens_batch.ndim > 0:
                num_tokens_batch = num_tokens_batch.sum()

            # Accumulate loss and tokens
            total_loss_sum += loss_sum.item()
            total_tokens += num_tokens_batch.item()

            # Accuracy calculation
            # Labels are the target sequence shifted left by one position;
            # padding positions are masked out of the count.
            tgt_labels = tgt[:, 1:]
            preds = logits.argmax(dim=-1)
            mask_acc = tgt_labels != model_cfg.pad_id
            correct_tokens += ((preds == tgt_labels) & mask_acc).sum().item()

            # Generation for BLEU/ChrF - only process if we still need samples
            if sample_count < max_samples:
                if use_autoregressive:
                    # True autoregressive generation including encoding
                    # Unwrap DataParallel to reach encode()/generate() directly.
                    raw_model = model.module if hasattr(model, "module") else model
                    enc = raw_model.encode(src)
                    generated_ids = raw_model.generate(
                        src,
                        max_len=model_cfg.max_len,
                        enc_output=enc,
                        bos_id=model_cfg.bos_id,
                        eos_id=model_cfg.eos_id,
                    )
                else:
                    # Teacher-forced predictions (fastest, uses existing logits)
                    generated_ids = preds

                for i in range(src.size(0)):
                    if sample_count >= max_samples:
                        break
                    # Post-process: stop at EOS or PAD tokens
                    ids = generated_ids[i].tolist()
                    # Find first EOS or PAD token and truncate
                    for idx, token_id in enumerate(ids):
                        if token_id == model_cfg.eos_id or token_id == model_cfg.pad_id:
                            ids = ids[:idx]
                            break
                    hyp = tgt_sp.decode(ids)
                    # NOTE(review): the reference is decoded without stripping
                    # EOS/PAD ids, unlike the hypothesis — presumably the
                    # SentencePiece decoder ignores control ids; confirm.
                    ref = tgt_sp.decode(tgt[i].tolist())
                    hypotheses.append(hyp)
                    references.append(ref)
                    sample_count += 1

    # Per-token average loss; guard against an empty validation set.
    avg_loss = total_loss_sum / max(1, total_tokens)
    # Clamp the exponent to avoid float overflow on divergent models.
    ppl = math.exp(min(avg_loss, 100))
    acc = correct_tokens / max(1, total_tokens)

    bleu = sacrebleu.corpus_bleu(hypotheses, [references]).score
    chrf = sacrebleu.corpus_chrf(hypotheses, [references]).score

    metrics = {"loss": avg_loss, "ppl": ppl, "acc": acc, "bleu": bleu, "chrf": chrf}

    print(
        f"\n{get_time_info()} [Validation] Loss: {avg_loss:.4f} | PPL: {ppl:.2f} | Acc: {acc:.4f} | BLEU: {bleu:.2f} | ChrF: {chrf:.2f}"
    )
    # Show a handful of decoded examples for a quick qualitative check.
    for i in range(min(train_cfg.quick_test_samples, len(hypotheses))):
        print(f"Sample {i}:")
        print(f"  Ref: {references[i]}")
        print(f"  Hyp: {hypotheses[i]}")
    print("-" * 30)

    # Restore training mode before returning to the training loop.
    model.train()
    return metrics