def convert_to_ct2_cli(experiment_dir: str, **kwargs):
"""
Convert a trained model to CTranslate2 format.
Args:
experiment_dir: Path to experiment directory
**kwargs: Overrides for configuration parameters
"""
model_cfg, data_cfg, train_cfg, export_cfg = load_config(
os.path.join(experiment_dir, "config.yaml")
)
# Apply overrides
for key, value in kwargs.items():
found = False
for cfg in [model_cfg, data_cfg, train_cfg, export_cfg]:
if hasattr(cfg, key):
setattr(cfg, key, value)
found = True
if not found:
print(f"Warning: Configuration key '{key}' not found in any config object.")
model_file = os.path.join(experiment_dir, "averaged_model.safetensors")
if not os.path.exists(model_file):
raise FileNotFoundError(f"Model file not found at {model_file}")
state_dict = load_file(model_file, device="cpu")
# Strip _orig_mod. prefix if present (from torch.compile)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_key = k.replace("_orig_mod.", "")
new_state_dict[new_key] = v
state_dict = new_state_dict
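    # Translate the configured activation name to the CTranslate2 enum; unknown names fall back to GELU.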
activation_map = {
"gelu": ctranslate2.specs.Activation.GELU,
"relu": ctranslate2.specs.Activation.RELU,
"swish": ctranslate2.specs.Activation.SWISH,
"silu": ctranslate2.specs.Activation.SWISH,
}
ct2_activation = activation_map.get(
model_cfg.activation, ctranslate2.specs.Activation.GELU
)
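    # Architecture flags from the model config, with defaults when the keys are absent.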
is_gated = getattr(model_cfg, "mlp_type", "standard") == "gated"
use_rms_norm = getattr(model_cfg, "norm_type", "layernorm") == "rmsnorm"
tie_decoder_embeddings = getattr(model_cfg, "tie_decoder_embeddings", False)
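    # Build encoder and decoder specs mirroring the trained (pre-norm) architecture.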
encoder_spec = ctranslate2.specs.TransformerEncoderSpec(
num_layers=model_cfg.enc_layers,
num_heads=model_cfg.n_heads,
pre_norm=True,
activation=ct2_activation,
ffn_glu=is_gated,
rms_norm=use_rms_norm,
)
decoder_spec = ctranslate2.specs.TransformerDecoderSpec(
num_layers=model_cfg.dec_layers,
num_heads=model_cfg.n_heads,
pre_norm=True,
activation=ct2_activation,
ffn_glu=is_gated,
rms_norm=use_rms_norm,
)
    # Map the trained weights onto the CTranslate2 spec.
# Embeddings
src_emb = state_dict.get("src_tok_emb.embedding.weight")
if src_emb is not None:
encoder_spec.embeddings[0].weight = (
src_emb.detach().float().cpu().numpy()
if hasattr(src_emb, "detach")
else src_emb.numpy()
)
tgt_emb = state_dict.get("tgt_tok_emb.embedding.weight")
if tgt_emb is None and tie_decoder_embeddings:
tgt_emb = state_dict.get("generator.weight")
if tgt_emb is not None:
decoder_spec.embeddings.weight = (
tgt_emb.detach().float().cpu().numpy()
if hasattr(tgt_emb, "detach")
else tgt_emb.numpy()
)
# Position Encodings
pe_tensor = state_dict.get("positional_encoding.pe")
if pe_tensor is not None:
pe = (
pe_tensor[0].detach().float().cpu().numpy()
if hasattr(pe_tensor, "detach")
else pe_tensor[0].numpy()
)
encoder_spec.position_encodings.encodings = pe
decoder_spec.position_encodings.encodings = pe
# Generator (Projection)
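    # With tied embeddings the output projection re-uses the target embedding matrix;
    # otherwise the generator weights are loaded directly.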
if tie_decoder_embeddings:
decoder_spec.projection.weight = decoder_spec.embeddings.weight
_, gen_bias = get_layer_weights(state_dict, "generator")
if gen_bias is not None:
decoder_spec.projection.bias = gen_bias
else:
decoder_spec.projection.bias = np.zeros(
decoder_spec.embeddings.weight.shape[0], dtype=np.float32
)
else:
set_linear(decoder_spec.projection, state_dict, "generator")
# 4. Encoder Layers
for i in range(model_cfg.enc_layers):
prefix = f"encoder.layers.{i}"
layer_spec = encoder_spec.layer[i]
set_multihead_attention(
layer_spec.self_attention,
state_dict,
f"{prefix}.self_attn",
self_attention=True,
)
set_layer_norm(
layer_spec.self_attention.layer_norm, state_dict, f"{prefix}.norm1"
)
if is_gated:
# gate_up_proj is fused [gate, up]
weight, bias = get_layer_weights(state_dict, f"{prefix}.ffn.gate_up_proj")
gate_w, up_w = np.split(weight, 2, axis=0)
layer_spec.ffn.linear_0.weight = gate_w
layer_spec.ffn.linear_0_noact.weight = up_w
if bias is not None:
gate_b, up_b = np.split(bias, 2)
layer_spec.ffn.linear_0.bias = gate_b
layer_spec.ffn.linear_0_noact.bias = up_b
else:
layer_spec.ffn.linear_0.bias = np.zeros(
gate_w.shape[0], dtype=np.float32
)
layer_spec.ffn.linear_0_noact.bias = np.zeros(
up_w.shape[0], dtype=np.float32
)
set_linear(layer_spec.ffn.linear_1, state_dict, f"{prefix}.ffn.down_proj")
else:
set_linear(layer_spec.ffn.linear_0, state_dict, f"{prefix}.ffn.linear1")
set_linear(layer_spec.ffn.linear_1, state_dict, f"{prefix}.ffn.linear2")
set_layer_norm(layer_spec.ffn.layer_norm, state_dict, f"{prefix}.norm2")
# Final Encoder Norm
set_layer_norm(encoder_spec.layer_norm, state_dict, "encoder.norm")
# 5. Decoder Layers
for i in range(model_cfg.dec_layers):
prefix = f"decoder.layers.{i}"
layer_spec = decoder_spec.layer[i]
set_multihead_attention(
layer_spec.self_attention,
state_dict,
f"{prefix}.self_attn",
self_attention=True,
)
set_layer_norm(
layer_spec.self_attention.layer_norm, state_dict, f"{prefix}.norm1"
)
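        # Cross-attention over the encoder output.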
set_multihead_attention(
layer_spec.attention,
state_dict,
f"{prefix}.multihead_attn",
self_attention=False,
)
set_layer_norm(layer_spec.attention.layer_norm, state_dict, f"{prefix}.norm2")
if is_gated:
# gate_up_proj is fused [gate, up]
weight, bias = get_layer_weights(state_dict, f"{prefix}.ffn.gate_up_proj")
gate_w, up_w = np.split(weight, 2, axis=0)
layer_spec.ffn.linear_0.weight = gate_w
layer_spec.ffn.linear_0_noact.weight = up_w
if bias is not None:
gate_b, up_b = np.split(bias, 2)
layer_spec.ffn.linear_0.bias = gate_b
layer_spec.ffn.linear_0_noact.bias = up_b
else:
layer_spec.ffn.linear_0.bias = np.zeros(
gate_w.shape[0], dtype=np.float32
)
layer_spec.ffn.linear_0_noact.bias = np.zeros(
up_w.shape[0], dtype=np.float32
)
set_linear(layer_spec.ffn.linear_1, state_dict, f"{prefix}.ffn.down_proj")
else:
set_linear(layer_spec.ffn.linear_0, state_dict, f"{prefix}.ffn.linear1")
set_linear(layer_spec.ffn.linear_1, state_dict, f"{prefix}.ffn.linear2")
set_layer_norm(layer_spec.ffn.layer_norm, state_dict, f"{prefix}.norm3")
# Final Decoder Norm
set_layer_norm(decoder_spec.layer_norm, state_dict, "decoder.norm")
# 6. Save model
    os.makedirs(export_cfg.output_dir, exist_ok=True)
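    # Assemble the full sequence-to-sequence spec and attach source BOS/EOS flags.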
spec = ctranslate2.specs.TransformerSpec(encoder_spec, decoder_spec)
spec.config.add_source_bos = export_cfg.add_source_bos # type: ignore
spec.config.add_source_eos = export_cfg.add_source_eos # type: ignore
# Register vocabularies
spec.register_source_vocabulary(
convert_vocab(f"{data_cfg.tokenizer_prefix_src}.vocab")
)
spec.register_target_vocabulary(
convert_vocab(f"{data_cfg.tokenizer_prefix_tgt}.vocab")
)
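    # Validate the spec, apply the configured quantization, and write the converted model.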
spec.validate()
spec.optimize(quantization=export_cfg.quantization)
spec.save(export_cfg.output_dir)
print(f"Model saved to {export_cfg.output_dir}")
# Copy Tokenizers to output directory
shutil.copy(
f"{data_cfg.tokenizer_prefix_src}.model",
Path(export_cfg.output_dir) / "src.spm.model",
)
shutil.copy(
f"{data_cfg.tokenizer_prefix_tgt}.model",
Path(export_cfg.output_dir) / "tgt.spm.model",
)
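
# Illustrative CLI wiring (a minimal sketch, not necessarily the project's real
# entry point): exposes the experiment directory as a positional argument and
# forwards --quantization as a config override handled by the kwargs loop above.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a trained model to CTranslate2 format."
    )
    parser.add_argument("experiment_dir", help="Path to the experiment directory")
    parser.add_argument(
        "--quantization",
        default=None,
        help="Override export quantization, e.g. int8 or float16",
    )
    args = parser.parse_args()

    overrides = {}
    if args.quantization is not None:
        overrides["quantization"] = args.quantization
    convert_to_ct2_cli(args.experiment_dir, **overrides)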