"""
GLiNER2 Extractor Model with Optimized Batch Processing

This module contains the core Extractor model that accepts PreprocessedBatch
directly for efficient GPU-only forward passes.
"""

import os
import tempfile
from typing import Dict, List, Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from gliner.modeling.span_rep import SpanRepLayer
from gliner2.layers import CountLSTMoE, CountLSTM, create_mlp, CountLSTMv2
from gliner2.processor import SchemaTransformer, PreprocessedBatch, SamplingConfig
from safetensors.torch import save_file, load_file
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoModel,
    AutoConfig,
    AutoTokenizer,
)


class ExtractorConfig(PretrainedConfig):
    """Configuration for the Extractor model.

    Args:
        model_name: Name or path of the encoder passed to ``AutoModel``.
        max_width: Maximum number of tokens in a candidate span.
        counting_layer: Which count-embedding module to build; one of
            ``"count_lstm"``, ``"count_lstm_moe"`` or ``"count_lstm_v2"``.
        token_pooling: Pooling mode forwarded to ``SchemaTransformer``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "extractor"

    def __init__(
            self,
            model_name: str = "bert-base-uncased",
            max_width: int = 8,
            counting_layer: str = "count_lstm",
            token_pooling: str = "first",
            **kwargs
    ):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.max_width = max_width
        self.counting_layer = counting_layer
        self.token_pooling = token_pooling


class Extractor(PreTrainedModel):
    """
    GLiNER2 Extractor Model.

    This model accepts PreprocessedBatch for efficient training.
    Use processor.collate_fn_train() to create batches.

    Example:
        >>> processor = SchemaTransformer(model_name)
        >>> model = Extractor.from_pretrained(repo_id)
        >>>
        >>> # Training
        >>> loader = DataLoader(dataset, collate_fn=processor.collate_fn_train)
        >>> for batch in loader:
        ...     batch = batch.to(device)
        ...     loss = model(batch)["total_loss"]
    """

    config_class = ExtractorConfig

    def __init__(self, config: ExtractorConfig, encoder_config=None, tokenizer=None):
        """Build the processor, encoder and task heads.

        Args:
            config: Model configuration.
            encoder_config: If given, the encoder is built from this config
                with freshly initialised weights (``AutoModel.from_config``);
                otherwise pretrained weights for ``config.model_name`` are
                loaded.
            tokenizer: Optional tokenizer handed to ``SchemaTransformer``;
                when omitted the processor is created from
                ``config.model_name``.

        Raises:
            ValueError: If ``config.counting_layer`` is not a known layer.
        """
        super().__init__(config)
        self.config = config
        self.max_width = config.max_width

        # Initialize processor
        if tokenizer is not None:
            self.processor = SchemaTransformer(
                tokenizer=tokenizer,
                token_pooling=config.token_pooling
            )
        else:
            self.processor = SchemaTransformer(
                config.model_name,
                token_pooling=config.token_pooling
            )

        # Load encoder
        if encoder_config is not None:
            self.encoder = AutoModel.from_config(encoder_config, trust_remote_code=True)
        else:
            self.encoder = AutoModel.from_pretrained(config.model_name, trust_remote_code=True)

        # Keep the embedding matrix in sync with the processor's tokenizer.
        self.encoder.resize_token_embeddings(len(self.processor.tokenizer))
        self.hidden_size = self.encoder.config.hidden_size

        # Span representation layer
        self.span_rep = SpanRepLayer(
            span_mode="markerV0",
            hidden_size=self.hidden_size,
            max_width=self.max_width,
            dropout=0.1,
        )

        # Classifier for classification tasks
        self.classifier = create_mlp(
            input_dim=self.hidden_size,
            intermediate_dims=[self.hidden_size * 2],
            output_dim=1,
            dropout=0.,
            activation="relu",
            add_layer_norm=False
        )

        # Count prediction layer (20 classes: counts 0..19)
        self.count_pred = create_mlp(
            input_dim=self.hidden_size,
            intermediate_dims=[self.hidden_size * 2],
            output_dim=20,
            dropout=0.,
            activation="relu",
            add_layer_norm=False
        )

        # Count embedding module
        if config.counting_layer == "count_lstm":
            self.count_embed = CountLSTM(self.hidden_size)
        elif config.counting_layer == "count_lstm_moe":
            self.count_embed = CountLSTMoE(
                hidden_size=self.hidden_size,
                n_experts=4,
                ffn_mult=2,
                dropout=0.1
            )
        elif config.counting_layer == "count_lstm_v2":
            self.count_embed = CountLSTMv2(hidden_size=self.hidden_size)
        else:
            # Fail fast: without this, `count_embed` would be missing and
            # forward() would swallow the per-sample AttributeError and
            # silently report zero loss.
            raise ValueError(
                f"Unknown counting_layer {config.counting_layer!r}; expected one of "
                "'count_lstm', 'count_lstm_moe', 'count_lstm_v2'."
            )

        # LoRA adapter state
        self._lora_layers = {}
        self._adapter_config = None

        self._print_config(config)

    def _print_config(self, config):
        """Print a short summary of the model configuration to stdout."""
        print("=" * 60)
        print("🧠 Model Configuration")
        print("=" * 60)
        print(f"Encoder model : {config.model_name}")
        print(f"Counting layer : {config.counting_layer}")
        print(f"Token pooling : {config.token_pooling}")
        print("=" * 60)

    # =========================================================================
    # Main Forward Pass
    # =========================================================================

    def forward(
            self,
            batch: PreprocessedBatch,
            return_individual_losses: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass on preprocessed batch.

        Samples are processed one by one; a sample whose loss computation
        raises is reported on stdout and contributes zero loss instead of
        aborting the whole batch.

        Args:
            batch: PreprocessedBatch from processor.collate_fn_train()
            return_individual_losses: If True, return per-sample losses

        Returns:
            Dict with:
                - total_loss: Sum of all losses
                - classification_loss: Classification task loss
                - structure_loss: Span extraction loss
                - count_loss: Count prediction loss
                - batch_size: Number of valid samples
                - individual_losses: (only if requested) list of per-sample
                  dicts of Python floats, plus an "error" message for
                  samples that failed
        """
        if len(batch) == 0:
            return self._empty_loss_dict()

        device = next(self.parameters()).device
        batch = batch.to(device)

        # Encode batch through transformer
        all_token_embs, all_schema_embs = self._encode_batch(batch)

        # Compute losses for each sample
        cls_losses = []
        struct_losses = []
        count_losses = []
        individual = []
        valid_samples = 0

        for i in range(len(batch)):
            try:
                sample_losses = self._compute_sample_loss(
                    token_embeddings=all_token_embs[i],
                    embs_per_schema=all_schema_embs[i],
                    task_types=batch.task_types[i],
                    structure_labels=batch.structure_labels[i],
                    device=device
                )

                cls_losses.append(sample_losses["classification"])
                struct_losses.append(sample_losses["structure"])
                count_losses.append(sample_losses["count"])

                if return_individual_losses:
                    individual.append({
                        "total_loss": (
                                sample_losses["classification"] +
                                sample_losses["structure"] +
                                sample_losses["count"]
                        ).item(),
                        "classification_loss": sample_losses["classification"].item(),
                        "structure_loss": sample_losses["structure"].item(),
                        "count_loss": sample_losses["count"].item(),
                    })

                valid_samples += 1

            except Exception as e:
                # Deliberate best-effort: one bad sample must not kill the step.
                print(f"Error processing sample {i}: {e}")
                zero = torch.tensor(0.0, device=device)
                cls_losses.append(zero)
                struct_losses.append(zero)
                count_losses.append(zero)

                if return_individual_losses:
                    individual.append({
                        "total_loss": 0.0,
                        "classification_loss": 0.0,
                        "structure_loss": 0.0,
                        "count_loss": 0.0,
                        "error": str(e)
                    })

        if valid_samples == 0:
            result = self._empty_loss_dict()
            if return_individual_losses:
                result["individual_losses"] = individual
            return result

        # Aggregate losses
        total_cls = torch.stack(cls_losses).sum()
        total_struct = torch.stack(struct_losses).sum()
        total_count = torch.stack(count_losses).sum()
        total_loss = total_cls + total_struct + total_count

        result = {
            "total_loss": total_loss,
            "classification_loss": total_cls,
            "structure_loss": total_struct,
            "count_loss": total_count,
            "batch_size": valid_samples
        }

        if return_individual_losses:
            result["individual_losses"] = individual

        return result

    def _empty_loss_dict(self) -> Dict[str, torch.Tensor]:
        """Return a zero-valued loss dict with ``batch_size`` 0.

        ``total_loss`` is created with ``requires_grad=True`` so that calling
        ``backward()`` on it does not fail.
        """
        device = next(self.parameters()).device
        return {
            "total_loss": torch.tensor(0.0, device=device, requires_grad=True),
            "classification_loss": torch.tensor(0.0, device=device),
            "structure_loss": torch.tensor(0.0, device=device),
            "count_loss": torch.tensor(0.0, device=device),
            "batch_size": 0
        }

    # =========================================================================
    # Encoding
    # =========================================================================

    def _encode_batch(
            self,
            batch: PreprocessedBatch
    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        """
        Encode batch through transformer and extract embeddings.

        Args:
            batch: PreprocessedBatch with input_ids and attention_mask

        Returns:
            - all_token_embs: List of (text_len, hidden) per sample
            - all_schema_embs: List of schema embeddings per sample
        """
        # Forward through encoder
        outputs = self.encoder(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask
        )
        token_embeddings = outputs.last_hidden_state

        # Extract embeddings using processor
        return self.processor.extract_embeddings_from_batch(
            token_embeddings,
            batch.input_ids,
            batch
        )

    # =========================================================================
    # Loss Computation
    # =========================================================================

    def _compute_sample_loss(
            self,
            token_embeddings: torch.Tensor,
            embs_per_schema: List[List[torch.Tensor]],
            task_types: List[str],
            structure_labels: List[Any],
            device: torch.device
    ) -> Dict[str, torch.Tensor]:
        """
        Compute all losses for a single sample.

        Args:
            token_embeddings: (text_len, hidden) text token embeddings
            embs_per_schema: List of schema embeddings
            task_types: Task type for each schema
            structure_labels: Labels for each schema
            device: Computation device

        Returns:
            Dict with classification, structure, and count losses
        """
        cls_loss = torch.tensor(0.0, device=device)
        struct_loss = torch.tensor(0.0, device=device)
        count_loss = torch.tensor(0.0, device=device)

        # Compute span representations if needed
        has_span_task = any(t != "classifications" for t in task_types)
        span_info = None
        if has_span_task and token_embeddings.numel() > 0:
            span_info = self.compute_span_rep(token_embeddings)

        all_counts = []
        all_p_embs = []

        for i, task_type in enumerate(task_types):
            if not embs_per_schema[i]:
                continue

            schema_emb = torch.stack(embs_per_schema[i])

            if task_type == "classifications":
                # Classification loss
                cls_embeds = schema_emb[1:]  # Skip [P] token
                logits = self.classifier(cls_embeds).squeeze(-1)
                labels = torch.tensor(structure_labels[i], dtype=torch.float, device=device)
                cls_loss = cls_loss + F.binary_cross_entropy_with_logits(
                    logits, labels, reduction="sum"
                )
            else:
                # Structure loss
                structure = structure_labels[i]

                if structure[0] == 0:
                    # No instances to extract
                    continue

                if span_info is not None:
                    struct_loss = struct_loss + self.compute_struct_loss(
                        span_info["span_rep"],
                        schema_emb,
                        structure,
                        span_info["span_mask"]
                    )

                # Collect for count loss (skip entities)
                if task_type != "entities":
                    # Counts are clamped to the 20-way count head (0..19).
                    all_counts.append(min(structure[0], 19))
                    all_p_embs.append(schema_emb[0])

        # Count loss
        if all_counts and all_p_embs:
            counts = torch.tensor(all_counts, dtype=torch.long, device=device)
            p_embs = torch.stack(all_p_embs)
            count_loss = F.cross_entropy(self.count_pred(p_embs), counts, reduction="sum")

        return {
            "classification": cls_loss,
            "structure": struct_loss,
            "count": count_loss
        }

    # =========================================================================
    # Span Representation
    # =========================================================================

    def compute_span_rep(self, token_embeddings: torch.Tensor) -> Dict[str, Any]:
        """
        Compute span representations for token embeddings.

        Candidate spans are enumerated start-major: for every start token
        ``i`` and every width offset ``j`` in ``[0, max_width)`` the span is
        ``(i, i + j)``. Spans running past the end of the text are marked
        invalid.

        Args:
            token_embeddings: (text_len, hidden) token embeddings

        Returns:
            Dict with:
                - span_rep: output of the span layer with the batch dim removed
                - spans_idx: (1, text_len * max_width, 2) start/end indices,
                  ``(-1, -1)`` for invalid spans
                - span_mask: (1, text_len * max_width) bool, True where invalid
        """
        text_length = len(token_embeddings)
        device = token_embeddings.device
        num_spans = text_length * self.max_width

        # Vectorised equivalent of the nested (start, width) Python loops.
        starts = torch.arange(text_length, dtype=torch.long, device=device)
        starts = starts.unsqueeze(1).expand(text_length, self.max_width)
        offsets = torch.arange(self.max_width, dtype=torch.long, device=device).unsqueeze(0)
        ends = starts + offsets

        raw_spans = torch.stack((starts, ends), dim=-1).reshape(1, num_spans, 2)

        # Mask invalid spans (end index beyond the last token)
        span_mask = (ends >= text_length).reshape(1, num_spans)
        invalid = span_mask.unsqueeze(-1)

        spans_idx = raw_spans.masked_fill(invalid, -1)

        # Replace invalid with (0, 0) for safe indexing
        safe_spans = raw_spans.masked_fill(invalid, 0)

        # Compute span representations
        span_rep = self.span_rep(
            token_embeddings.unsqueeze(0),
            safe_spans
        ).squeeze(0)

        return {
            "span_rep": span_rep,
            "spans_idx": spans_idx,
            "span_mask": span_mask
        }

    def compute_struct_loss(
            self,
            span_rep: torch.Tensor,
            schema_emb: torch.Tensor,
            structure: List[Any],
            span_mask: torch.Tensor,
            masking_rate: float = 0.5
    ) -> torch.Tensor:
        """
        Compute structure extraction loss with negative span masking.

        Args:
            span_rep: (text_len, max_width, hidden) span representations
                (3-D, as required by the einsum below)
            schema_emb: (num_fields + 1, hidden) schema embeddings; row 0 is
                skipped here
            structure: [count, spans] structure labels, where ``spans[i][k]``
                is the gold span(s) of field ``k`` in instance ``i``: either a
                ``(start, end)`` tuple, a list of such tuples, or None
            span_mask: (1, num_spans) mask, True for invalid spans
            masking_rate: Probability of masking negative spans (training only)

        Returns:
            Structure loss tensor (scalar, summed over all unmasked entries)
        """
        gold_count = min(structure[0], 19)
        struct_proj = self.count_embed(schema_emb[1:], gold_count)
        # scores: (instances, fields, start, width)
        scores = torch.einsum('lkd,bpd->bplk', span_rep, struct_proj)

        # Create label tensor
        labs = torch.zeros_like(scores)
        num_starts, num_widths = scores.shape[2], scores.shape[3]

        for i in range(gold_count):
            gold_spans = structure[1][i]
            for k, span in enumerate(gold_spans):
                if span is None or span == (-1, -1):
                    continue

                # A field holds either one (start, end) tuple or a list of them.
                if isinstance(span, tuple):
                    candidates = (span,)
                elif isinstance(span, list):
                    candidates = span
                else:
                    continue

                for sub in candidates:
                    if sub is None or sub == (-1, -1):
                        continue
                    start, end = sub
                    width = end - start
                    # Gold spans outside the candidate grid are silently dropped.
                    if 0 <= start < num_starts and 0 <= width < num_widths:
                        labs[i, k, start, width] = 1

        # Apply negative masking
        if masking_rate > 0.0 and self.training:
            negative = (labs == 0)
            random_mask = torch.rand_like(scores) < masking_rate
            to_mask = negative & random_mask
            loss_mask = (~to_mask).float()
        else:
            loss_mask = torch.ones_like(scores)

        # Compute masked loss
        loss = F.binary_cross_entropy_with_logits(scores, labs, reduction="none")
        loss = loss * loss_mask
        # Flatten (start, width) so the invalid-span mask can be broadcast.
        loss = loss.view(loss.shape[0], loss.shape[1], -1) * (~span_mask[0]).float()

        return loss.sum()

    # =========================================================================
    # Hugging Face Methods
    # =========================================================================

    def push_to_hub(self, repo_id: str, private: bool = True):
        """Push model to Hugging Face Hub.

        Saves the model into a temporary directory, delegates the upload to
        the parent class and then pushes the tokenizer to the same repo.

        NOTE(review): ``save_pretrained`` defaults to ``merge_lora=True``, so
        an active LoRA adapter is merged into the model by this call.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.save_pretrained(tmp_dir)
            super().push_to_hub(repo_id=repo_id, save_dir=tmp_dir, private=private)
            self.processor.tokenizer.push_to_hub(repo_id)

    @classmethod
    def from_pretrained(cls, repo_or_dir: str, **kwargs):
        """
        Load model from Hugging Face Hub or local directory.

        To use a LoRA adapter:
            1. Load the base model first
            2. Then load the adapter using model.load_adapter()

        Args:
            repo_or_dir: Local directory or Hub repo id containing
                ``config.json``, ``encoder_config/config.json``, the
                tokenizer files and ``model.safetensors`` (or
                ``pytorch_model.bin`` as a fallback).
            **kwargs: Accepted for signature compatibility; currently unused.

        Returns:
            The loaded model.

        Example:
            model = Extractor.from_pretrained("base-model-name")
            model.load_adapter("path/to/adapter")
        """
        from huggingface_hub import hf_hub_download

        def download_or_local(repo, filename):
            if os.path.isdir(repo):
                return os.path.join(repo, filename)
            return hf_hub_download(repo, filename)

        config_path = download_or_local(repo_or_dir, "config.json")
        config = cls.config_class.from_pretrained(config_path)

        encoder_config_path = download_or_local(repo_or_dir, "encoder_config/config.json")
        encoder_config = AutoConfig.from_pretrained(encoder_config_path)

        tokenizer = AutoTokenizer.from_pretrained(repo_or_dir)
        model = cls(config, encoder_config=encoder_config, tokenizer=tokenizer)

        # Load weights
        try:
            model_path = download_or_local(repo_or_dir, "model.safetensors")
            state_dict = load_file(model_path)
        except Exception:
            model_path = download_or_local(repo_or_dir, "pytorch_model.bin")
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from trusted sources, or
            # consider `weights_only=True` where the installed torch supports it.
            state_dict = torch.load(model_path, map_location="cpu")

        # Handle embedding size mismatch
        emb_key = "encoder.embeddings.word_embeddings.weight"
        saved_emb = state_dict.get(emb_key)
        if saved_emb is not None:
            model_emb = model.encoder.embeddings.word_embeddings.weight
            extra = model_emb.shape[0] - saved_emb.shape[0]
            if extra > 0:
                # Model has more tokens than the checkpoint: pad the checkpoint
                # with small random rows for the new tokens.
                padding = (torch.randn(extra, saved_emb.shape[1]) * 0.02).to(saved_emb.dtype)
                state_dict[emb_key] = torch.cat([saved_emb, padding], dim=0)
            elif extra < 0:
                # Checkpoint has more rows than the model: grow the model so
                # the saved weights load as-is (a negative `extra` used to
                # crash in torch.randn).
                model.encoder.resize_token_embeddings(saved_emb.shape[0])

        model.load_state_dict(state_dict)
        return model

    def load_adapter(self, adapter_path: str) -> 'Extractor':
        """
        Load a LoRA adapter onto this model.

        If an adapter is already loaded, it will be unloaded first.

        Args:
            adapter_path: Path to adapter directory

        Returns:
            self for method chaining

        Example:
            model.load_adapter("./legal_adapter")
            results = model.extract_entities(text, entities)
        """
        from gliner2.training.lora import load_lora_adapter, LoRAAdapterConfig

        # Load adapter config
        config = LoRAAdapterConfig.load(adapter_path)

        self._lora_layers = load_lora_adapter(self, adapter_path, auto_unload=True)
        self._adapter_config = config

        return self

    def unload_adapter(self) -> 'Extractor':
        """
        Unload current LoRA adapter, restoring base model.

        Does nothing if no adapter is loaded.

        Returns:
            self for method chaining
        """
        from gliner2.training.lora import unload_lora_adapter

        if self._lora_layers:
            unload_lora_adapter(self)
            self._lora_layers = {}
            self._adapter_config = None

        return self

    def merge_lora(self) -> 'Extractor':
        """
        Merge LoRA weights into base model and remove adapter structure.

        After calling this, the model will have standard Linear layers with
        merged weights. LoRA adapters are permanently removed.

        Returns:
            self for method chaining

        Raises:
            ValueError: If no adapter is loaded

        Example:
            model.load_adapter("./my_adapter")
            model.merge_lora()  # Now model has merged weights, no LoRA
            model.save_pretrained("./merged_model")
        """
        if not self._lora_layers:
            raise ValueError("No adapter loaded. Nothing to merge.")

        from gliner2.training.lora import merge_lora_weights

        merge_lora_weights(self)
        self._lora_layers = {}
        self._adapter_config = None

        return self

    def save_adapter(self, save_path: str) -> None:
        """
        Save only the LoRA adapter (not full model).

        Args:
            save_path: Directory to save adapter

        Raises:
            ValueError: If no adapter is loaded
        """
        if not self._lora_layers:
            raise ValueError("No adapter loaded. Use save_pretrained for full model.")

        from gliner2.training.lora import save_lora_adapter

        save_lora_adapter(self, save_path)

    @property
    def has_adapter(self) -> bool:
        """Check if an adapter is currently loaded."""
        return bool(self._lora_layers)

    @property
    def adapter_config(self):
        """Get config of loaded adapter, or None."""
        return self._adapter_config

    def save_pretrained(
            self,
            save_directory: str,
            save_adapter_only: bool = False,
            merge_lora: bool = True,
            **kwargs
    ):
        """
        Save model to directory.

        Writes the model config, the encoder config (under
        ``encoder_config/``), the weights as ``model.safetensors`` and the
        tokenizer files.

        Args:
            save_directory: Where to save
            save_adapter_only: If True and adapter loaded, save only adapter
            merge_lora: If True and LoRA active, merge LoRA weights into base
                model and remove adapter structure before saving.
                WARNING: This permanently removes LoRA from the model instance.
            **kwargs: Accepted for signature compatibility; currently unused.

        Raises:
            ValueError: If ``save_adapter_only`` is True but no adapter is
                loaded.
        """
        if save_adapter_only:
            if not self._lora_layers:
                raise ValueError("save_adapter_only=True but no adapter loaded")
            self.save_adapter(save_directory)
            return

        # Handle LoRA merging if requested
        if merge_lora and self._lora_layers:
            self.merge_lora()

        # Original save logic
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        encoder_config_path = os.path.join(save_directory, "encoder_config")
        os.makedirs(encoder_config_path, exist_ok=True)
        self.encoder.config.save_pretrained(encoder_config_path)

        model_path = os.path.join(save_directory, "model.safetensors")
        save_file(self.state_dict(), model_path)

        self.processor.tokenizer.save_pretrained(save_directory)