Skip to content

Checkpoint Averaging

average_checkpoints

average_checkpoints_cli(experiment_dir, **kwargs)

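Average the best k checkpoints (lowest validation perplexity, falling back to the k most recent checkpoints when no metrics are available) and save the result as safetensors and, optionally, INT8.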

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `experiment_dir` | `str` | Path to experiment directory | *required* |
| `**kwargs` | | Overrides for configuration parameters | `{}` |
Source code in src/quickmt_train/average_checkpoints.py
def average_checkpoints_cli(experiment_dir: str, **kwargs):
    """
    Average the best k checkpoints and save as safetensors/INT8.

    Checkpoint selection:
      * If ``<experiment_dir>/metrics.jsonl`` exists, its records are ranked by
        ``ppl`` (lower is better; records without a usable ``ppl`` rank last)
        and the first ``export_cfg.k`` distinct steps whose
        ``model_<step>.safetensors`` file exists in ``train_cfg.checkpoint_dir``
        are used.
      * Otherwise, or if none of those files exist, the ``export_cfg.k``
        ``model_<step>*.safetensors`` files with the highest step numbers are
        used (files containing ``_int8`` are ignored).

    Outputs (paths start with ``export_cfg.output_prefix``):
      * ``.pt`` -- ``{"model_state_dict": <averaged state dict>}``
      * ``.safetensors`` -- averaged weights written via ``save_model``
      * ``_int8.pt`` / ``_int8.safetensors`` -- only when
        ``export_cfg.export_int8`` is set; the model is calibrated on the dev
        loader before conversion.

    Args:
        experiment_dir: Path to experiment directory; ``config.yaml`` is read
            from it.
        **kwargs: Overrides for configuration parameters. Each key is set on
            every config object that has an attribute with that name.
    """
    import math

    model_cfg, data_cfg, train_cfg, export_cfg = load_config(
        os.path.join(experiment_dir, "config.yaml")
    )

    # Apply overrides: a key is set on every config object that defines it.
    for key, value in kwargs.items():
        found = False
        for cfg in (model_cfg, data_cfg, train_cfg, export_cfg):
            if hasattr(cfg, key):
                setattr(cfg, key, value)
                found = True
        if not found:
            print(f"Warning: Configuration key '{key}' not found in any config object.")

    ckpt_dir = train_cfg.checkpoint_dir

    def _report_load(result, source: str) -> None:
        # strict=False is kept (some keys may legitimately be absent, e.g.
        # shared tensors -- assumed), but mismatches are reported rather than
        # silently ignored.
        if result.missing_keys:
            print(
                f"Warning: {len(result.missing_keys)} model keys missing from "
                f"{source}: {result.missing_keys[:10]}"
            )
        if result.unexpected_keys:
            print(
                f"Warning: {len(result.unexpected_keys)} unexpected keys in "
                f"{source} were ignored: {result.unexpected_keys[:10]}"
            )

    # 1. Select checkpoints: best k by validation perplexity when metrics are
    #    available, otherwise the k most recent (highest step) checkpoints.
    selected: list[str] = []
    use_fallback = True
    metrics_path = os.path.join(experiment_dir, "metrics.jsonl")
    if os.path.exists(metrics_path):
        print(f"Reading metrics from {metrics_path}")
        metrics = []
        with open(metrics_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank / trailing lines
                    metrics.append(json.loads(line))

        def _ppl(record: dict) -> float:
            # Missing, null or NaN perplexities rank last; NaN would otherwise
            # make the sort order meaningless.
            value = record.get("ppl")
            if value is None:
                return float("inf")
            value = float(value)
            return float("inf") if math.isnan(value) else value

        # Sort by perplexity ascending (lower is better).
        metrics.sort(key=_ppl)

        # A step may appear in several records; keep only its best-ranked
        # occurrence so the same checkpoint is never averaged twice.
        best_steps = dict.fromkeys(m["step"] for m in metrics if "step" in m)
        available = [
            name
            for name in (f"model_{step}.safetensors" for step in best_steps)
            if os.path.exists(os.path.join(ckpt_dir, name))
        ]
        if available:
            selected = available[: export_cfg.k]
            use_fallback = False
            print(f"Selected {len(selected)} best checkpoints based on PPL.")
        else:
            print(
                "None of the checkpoints listed in the metrics file were found. "
                "Falling back to last k checkpoints."
            )
    else:
        print(
            f"Metrics file {metrics_path} not found. Falling back to last k checkpoints."
        )

    if use_fallback:
        if not os.path.exists(ckpt_dir):
            print(f"Directory {ckpt_dir} not found.")
            return

        prefix, suffix = "model_", ".safetensors"
        steps: dict[str, int] = {}
        for name in os.listdir(ckpt_dir):
            if not (name.startswith(prefix) and name.endswith(suffix)) or "_int8" in name:
                continue
            # "model_<step>.safetensors" or "model_<step>_<tag>.safetensors"
            step_str = name[len(prefix) : -len(suffix)].split("_")[0]
            if step_str.isdigit():
                steps[name] = int(step_str)
            else:
                print(f"Warning: Could not parse step from {name}. Skipping.")

        # Sort by step number, newest first
        checkpoints = sorted(steps, key=steps.__getitem__, reverse=True)
        selected = checkpoints[: export_cfg.k]

    if not selected:
        print("No model files found.")
        return

    print(f"Averaging {len(selected)} model checkpoints:")
    for name in selected:
        print(f" - {name}")

    # 2. Load and average state dicts.
    # Floating-point tensors are accumulated in at least float32 so fp16/bf16
    # checkpoints do not lose precision (or overflow) while summing; they are
    # cast back to their original dtype after dividing.
    avg_state_dict: dict[str, torch.Tensor] = {}
    key_counts: dict[str, int] = {}
    orig_dtypes: dict[str, torch.dtype] = {}
    count = len(selected)

    for idx, ckpt_name in enumerate(selected):
        state_dict = load_file(os.path.join(ckpt_dir, ckpt_name), device="cpu")
        for k, tensor in state_dict.items():
            if idx == 0:
                orig_dtypes[k] = tensor.dtype
                if tensor.is_floating_point():
                    tensor = tensor.to(torch.promote_types(tensor.dtype, torch.float32))
                avg_state_dict[k] = tensor
                key_counts[k] = 1
            elif k not in avg_state_dict:
                # This might happen if mixing different architectures
                print(f"Warning: Key {k} not found in first checkpoint. Skipping.")
            elif tensor.shape != avg_state_dict[k].shape:
                print(
                    f"Warning: Shape mismatch for key {k} in {ckpt_name} "
                    f"({tuple(tensor.shape)} vs {tuple(avg_state_dict[k].shape)}). Skipping."
                )
            else:
                # Boolean buffers cannot be summed meaningfully; keep the first
                # checkpoint's value for them.
                if avg_state_dict[k].dtype != torch.bool:
                    avg_state_dict[k] += tensor.to(avg_state_dict[k].dtype)
                key_counts[k] += 1

    # Divide each key by the number of checkpoints that actually contributed.
    for k, total in avg_state_dict.items():
        n = key_counts[k]
        if n != count:
            print(
                f"Warning: Key {k} present in only {n} of {count} checkpoints; "
                f"averaging over {n}."
            )
        if total.is_floating_point():
            avg_state_dict[k] = (total / n).to(orig_dtypes[k])
        elif total.dtype != torch.bool:
            avg_state_dict[k] = torch.div(total, n, rounding_mode="floor")

    # 3. Save as .pt and .safetensors (averaged weights)
    # Ensure experiment directory exists
    os.makedirs(train_cfg.experiment_name, exist_ok=True)
    # torch.save does not create missing parent directories of the output path.
    output_dir = os.path.dirname(export_cfg.output_prefix)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    pt_output = f"{export_cfg.output_prefix}.pt"
    torch.save({"model_state_dict": avg_state_dict}, pt_output)

    st_output = f"{export_cfg.output_prefix}.safetensors"
    # Create model to handle shared tensors correctly during save
    model = Seq2SeqTransformer(model_cfg).to("cpu")
    _report_load(model.load_state_dict(avg_state_dict, strict=False), "averaged weights")
    save_model(model, st_output)
    print(f"Saved averaged model to {pt_output} and {st_output}")

    # 4. Calibration and INT8 Export
    if not export_cfg.export_int8:
        return

    print("\nStarting re-calibration for INT8 export...")

    # Override settings for calibration
    data_cfg.max_tokens_per_batch = 2048
    data_cfg.buffer_size = 10000
    data_cfg.num_workers = 0
    _, dev_loader, _, _ = PrepareData(model_cfg, data_cfg, train_cfg)

    model = Seq2SeqTransformer(model_cfg).to("cpu")

    # Load averaged weights BEFORE preparing for quantization
    print("Loading averaged weights...")
    model.load_state_dict(avg_state_dict, strict=False)
    model.eval()

    # Prepare model for Post-Training Quantization (PTQ)
    print("Preparing model for Post-Training Quantization (PTQ)...")
    if export_cfg.qconfig_backend == "fbgemm":
        # QConfig is an immutable namedtuple: assigning to `.activation` raises
        # AttributeError, so build a new QConfig with a histogram activation
        # observer and fbgemm's default weight observer.
        base_qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model.qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.HistogramObserver.with_args(
                reduce_range=True
            ),
            weight=base_qconfig.weight,
        )
    else:
        model.qconfig = torch.ao.quantization.get_default_qconfig(
            export_cfg.qconfig_backend
        )

    # Keep attention and embedding modules un-quantized: any module whose
    # qualified name contains one of these substrings gets no qconfig.
    for name, module in model.named_modules():
        if any(part in name for part in ("self_attn", "multihead_attn", "emb")):
            module.qconfig = None  # type: ignore

    # Prepare the model (inserts observers)
    torch.ao.quantization.prepare(model, inplace=True)

    # Calibrate
    model.calibrate(dev_loader, num_batches=export_cfg.calib_batches)

    # Convert and Save
    model.convert_to_int8()
    int8_state_dict = model.state_dict()
    int8_output = f"{export_cfg.output_prefix}_int8.pt"
    torch.save({"model_state_dict": int8_state_dict}, int8_output)
    print(f"Saved calibrated INT8 model to {int8_output}")

    # Also save as safetensors (dequantized) for easier loading
    st_int8_output = f"{export_cfg.output_prefix}_int8.safetensors"
    temp_model = Seq2SeqTransformer(model_cfg).to("cpu")
    # Keys of the float model; used to tell genuine parameters named
    # "*.scale"/"*.zero_point" apart from quantization parameters.
    float_keys = set(temp_model.state_dict())
    packed_suffix = "._packed_params._packed_params"

    dequantized_state_dict: dict[str, torch.Tensor] = {}
    for k, v in int8_state_dict.items():
        if k.endswith(packed_suffix) and isinstance(v, tuple):
            # Packed (weight, bias) tuple: map back to <prefix>.weight / .bias
            prefix = k[: -len(packed_suffix)]
            qweight, bias = v
            dequantized_state_dict[f"{prefix}.weight"] = (
                qweight.dequantize() if getattr(qweight, "is_quantized", False) else qweight
            )
            if bias is not None:
                dequantized_state_dict[f"{prefix}.bias"] = bias
        elif not isinstance(v, torch.Tensor):
            # Non-tensor entries (like .dtype) are not supported by safetensors
            continue
        elif v.is_quantized:
            dequantized_state_dict[k] = v.dequantize()
        elif (k.endswith(".scale") or k.endswith(".zero_point")) and k not in float_keys:
            # Quantization parameters are baked into the dequantized weights
            continue
        else:
            dequantized_state_dict[k] = v

    _report_load(
        temp_model.load_state_dict(dequantized_state_dict, strict=False),
        "dequantized INT8 weights",
    )
    save_model(temp_model, st_int8_output)
    print(f"Saved dequantized INT8 model to {st_int8_output}")