""" GLiNER2 - Advanced Information Extraction Engine This module provides the main GLiNER2 class with optimized batch processing using DataLoader-based parallel preprocessing. Example: >>> from gliner2 import GLiNER2 >>> >>> extractor = GLiNER2.from_pretrained("model-repo") >>> >>> # Simple extraction >>> results = extractor.extract_entities( ... "Apple released iPhone 15.", ... ["company", "product"] ... ) >>> >>> # Batch extraction (parallel preprocessing) >>> results = extractor.batch_extract_entities( ... texts_list, ... ["company", "product"], ... batch_size=32, ... num_workers=4 ... ) """ from __future__ import annotations import re import hashlib import json from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING, Pattern, Literal import torch import torch.nn.functional as F from torch.utils.data import DataLoader from gliner2.model import Extractor from gliner2.processor import PreprocessedBatch from gliner2.training.trainer import ExtractorCollator if TYPE_CHECKING: from gliner2.api_client import GLiNER2API # ============================================================================= # Validators # ============================================================================= @dataclass class RegexValidator: """Regex-based span filter for post-processing.""" pattern: str | Pattern[str] mode: Literal["full", "partial"] = "full" exclude: bool = False flags: int = re.IGNORECASE _compiled: Pattern[str] = field(init=False, repr=False) def __post_init__(self): if self.mode not in {"full", "partial"}: raise ValueError(f"mode must be 'full' or 'partial', got {self.mode!r}") try: compiled = ( self.pattern if isinstance(self.pattern, re.Pattern) else re.compile(self.pattern, self.flags) ) except re.error as err: raise ValueError(f"Invalid regex: {self.pattern!r}") from err object.__setattr__(self, "_compiled", compiled) def __call__(self, text: str) -> bool: return self.validate(text) def validate(self, text: str) -> bool: matcher = self._compiled.fullmatch if self.mode == "full" else self._compiled.search matched = matcher(text) is not None return not matched if self.exclude else matched # ============================================================================= # Schema Builder # ============================================================================= class StructureBuilder: """Builder for structured data schemas.""" def __init__(self, schema: 'Schema', parent: str): self.schema = schema self.parent = parent self.fields = OrderedDict() self.descriptions = OrderedDict() self.field_order = [] self._finished = False def field( self, name: str, dtype: Literal["str", "list"] = "list", choices: Optional[List[str]] = None, description: Optional[str] = None, threshold: Optional[float] = None, validators: Optional[List[RegexValidator]] = None ) -> 'StructureBuilder': """Add a field to the structure.""" self.fields[name] = {"value": "", "choices": choices} if choices else "" self.field_order.append(name) if description: self.descriptions[name] = description self.schema._store_field_metadata(self.parent, name, dtype, threshold, choices, validators) return self def _auto_finish(self): if not self._finished: self.schema._store_field_order(self.parent, self.field_order) self.schema.schema["json_structures"].append({self.parent: self.fields}) if self.descriptions: if "json_descriptions" not in self.schema.schema: self.schema.schema["json_descriptions"] = {} self.schema.schema["json_descriptions"][self.parent] = self.descriptions self._finished = True def __getattr__(self, name): if hasattr(self.schema, name): self._auto_finish() return getattr(self.schema, name) raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") class Schema: """Schema builder for extraction tasks.""" def __init__(self): self.schema = { "json_structures": [], "classifications": [], "entities": OrderedDict(), "relations": [], "json_descriptions": {}, "entity_descriptions": OrderedDict() } self._field_metadata = {} self._entity_metadata = {} self._relation_metadata = {} self._field_orders = {} self._entity_order = [] self._relation_order = [] self._active_builder = None def _store_field_metadata(self, parent, field, dtype, threshold, choices, validators=None): if threshold is not None and not 0 <= threshold <= 1: raise ValueError(f"Threshold must be 0-1, got {threshold}") self._field_metadata[f"{parent}.{field}"] = { "dtype": dtype, "threshold": threshold, "choices": choices, "validators": validators or [] } def _store_entity_metadata(self, entity, dtype, threshold): if threshold is not None and not 0 <= threshold <= 1: raise ValueError(f"Threshold must be 0-1, got {threshold}") self._entity_metadata[entity] = {"dtype": dtype, "threshold": threshold} def _store_field_order(self, parent, order): self._field_orders[parent] = order def structure(self, name: str) -> StructureBuilder: """Start building a structure schema.""" if self._active_builder: self._active_builder._auto_finish() self._active_builder = StructureBuilder(self, name) return self._active_builder def classification( self, task: str, labels: Union[List[str], Dict[str, str]], multi_label: bool = False, cls_threshold: float = 0.5, **kwargs ) -> 'Schema': """Add classification task.""" if self._active_builder: self._active_builder._auto_finish() self._active_builder = None label_names = list(labels.keys()) if isinstance(labels, dict) else labels label_descs = labels if isinstance(labels, dict) else None config = { "task": task, "labels": label_names, "multi_label": multi_label, "cls_threshold": cls_threshold, "true_label": ["N/A"], **kwargs } if label_descs: config["label_descriptions"] = label_descs self.schema["classifications"].append(config) return self def entities( self, entity_types: Union[str, List[str], Dict[str, Union[str, Dict]]], dtype: Literal["str", "list"] = "list", threshold: Optional[float] = None ) -> 'Schema': """Add entity extraction task.""" if self._active_builder: self._active_builder._auto_finish() self._active_builder = None entities = self._parse_entity_input(entity_types) for name, config in entities.items(): self.schema["entities"][name] = "" if name not in self._entity_order: self._entity_order.append(name) self._store_entity_metadata( name, config.get("dtype", dtype), config.get("threshold", threshold) ) if "description" in config: self.schema["entity_descriptions"][name] = config["description"] return self def _parse_entity_input(self, entity_types): if isinstance(entity_types, str): return {entity_types: {}} elif isinstance(entity_types, list): return {name: {} for name in entity_types} elif isinstance(entity_types, dict): result = {} for name, config in entity_types.items(): if isinstance(config, str): result[name] = {"description": config} elif isinstance(config, dict): result[name] = config else: result[name] = {} return result raise ValueError("Invalid entity_types format") def relations( self, relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]], threshold: Optional[float] = None ) -> 'Schema': """Add relation extraction task.""" if self._active_builder: self._active_builder._auto_finish() self._active_builder = None if isinstance(relation_types, str): relations = {relation_types: {}} elif isinstance(relation_types, list): relations = {name: {} for name in relation_types} elif isinstance(relation_types, dict): relations = {} for name, config in relation_types.items(): relations[name] = {"description": config} if isinstance(config, str) else (config if isinstance(config, dict) else {}) else: raise ValueError("Invalid relation_types format") for name, config in relations.items(): self.schema["relations"].append({name: {"head": "", "tail": ""}}) if name not in self._relation_order: self._relation_order.append(name) self._field_orders[name] = ["head", "tail"] rel_threshold = config.get("threshold", threshold) if rel_threshold is not None and not 0 <= rel_threshold <= 1: raise ValueError(f"Threshold must be 0-1, got {rel_threshold}") self._relation_metadata[name] = {"threshold": rel_threshold} return self def build(self) -> Dict[str, Any]: """Build final schema dictionary.""" if self._active_builder: self._active_builder._auto_finish() self._active_builder = None return self.schema @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'Schema': """Create a Schema from a dictionary. Args: data: Dictionary with optional keys: entities, structures, classifications, relations Returns: Schema: Constructed schema instance Raises: ValidationError: If the input data is invalid Example: >>> schema_dict = { ... "entities": ["company", "person"], ... "structures": { ... "product_info": { ... "fields": [ ... {"name": "company", "dtype": "str"}, ... {"name": "product"} ... ] ... } ... }, ... "classifications": [ ... {"task": "sentiment", "labels": ["positive", "negative"]} ... ], ... "relations": ["works_for", "founded_by"] ... } >>> schema = Schema.from_dict(schema_dict) """ from gliner2.inference.schema_model import SchemaInput # Validate input validated = SchemaInput(**data) # Build schema using builder API schema = cls() # Add entities if validated.entities is not None: schema.entities(validated.entities) # Add structures if validated.structures is not None: for struct_name, struct_input in validated.structures.items(): builder = schema.structure(struct_name) for field_input in struct_input.fields: builder.field( name=field_input.name, dtype=field_input.dtype, choices=field_input.choices, description=field_input.description ) # Auto-finish the builder builder._auto_finish() # Add classifications if validated.classifications is not None: for cls_input in validated.classifications: schema.classification( task=cls_input.task, labels=cls_input.labels, multi_label=cls_input.multi_label ) # Add relations if validated.relations is not None: schema.relations(validated.relations) return schema @classmethod def from_json(cls, json_str: str) -> 'Schema': """Create a Schema from a JSON string. Args: json_str: JSON string with schema definition Returns: Schema: Constructed schema instance Raises: ValidationError: If the input data is invalid json.JSONDecodeError: If the JSON is malformed Example: >>> schema_json = ''' ... { ... "entities": ["company", "person"], ... "classifications": [ ... {"task": "sentiment", "labels": ["positive", "negative"]} ... ] ... } ... ''' >>> schema = Schema.from_json(schema_json) """ data = json.loads(json_str) return cls.from_dict(data) def to_dict(self) -> Dict[str, Any]: """Convert schema to user-friendly dictionary format. Returns: Dict: Schema in dictionary format compatible with from_dict() Example: >>> schema = Schema() >>> schema.entities(["company", "person"]) >>> schema_dict = schema.to_dict() >>> # schema_dict can be used with Schema.from_dict() """ result = {} # Export entities if self.schema["entities"]: # Check if we have descriptions if self.schema["entity_descriptions"]: result["entities"] = dict(self.schema["entity_descriptions"]) else: result["entities"] = list(self.schema["entities"].keys()) # Export structures if self.schema["json_structures"]: result["structures"] = {} for struct_dict in self.schema["json_structures"]: for struct_name, struct_fields in struct_dict.items(): fields = [] field_order = self._field_orders.get(struct_name, []) for field_name in field_order: if field_name not in struct_fields: continue field_key = f"{struct_name}.{field_name}" metadata = self._field_metadata.get(field_key, {}) field_def = {"name": field_name} # Add dtype if not default dtype = metadata.get("dtype", "list") if dtype != "list": field_def["dtype"] = dtype # Add choices if present choices = metadata.get("choices") if choices: field_def["choices"] = choices # Add description if present desc = self.schema.get("json_descriptions", {}).get(struct_name, {}).get(field_name) if desc: field_def["description"] = desc fields.append(field_def) result["structures"][struct_name] = {"fields": fields} # Export classifications if self.schema["classifications"]: result["classifications"] = [] for cls_config in self.schema["classifications"]: cls_def = { "task": cls_config["task"], "labels": cls_config["labels"] } if cls_config.get("multi_label", False): cls_def["multi_label"] = True result["classifications"].append(cls_def) # Export relations if self.schema["relations"]: result["relations"] = self._relation_order if self._relation_order else [ list(rel_dict.keys())[0] for rel_dict in self.schema["relations"] ] return result # ============================================================================= # Main GLiNER2 Class # ============================================================================= class GLiNER2(Extractor): """ GLiNER2 Information Extraction Model. Provides efficient batch extraction with parallel preprocessing. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._schema_cache = {} # OPT-11: Cached collator instance for inference self._inference_collator = None @classmethod def from_api(cls, api_key: str = None, api_base_url: str = None, timeout: float = 30.0, max_retries: int = 3) -> 'GLiNER2API': """Load from API instead of local model.""" from gliner2.api_client import GLiNER2API return GLiNER2API(api_key=api_key, api_base_url=api_base_url, timeout=timeout, max_retries=max_retries) def create_schema(self) -> Schema: """Create a new schema builder.""" return Schema() # ========================================================================= # Main Batch Extraction # ========================================================================= @torch.inference_mode() def batch_extract( self, texts: List[str], schemas: Union[Schema, List[Schema], Dict, List[Dict]], batch_size: int = 8, threshold: float = 0.5, num_workers: int = 0, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False ) -> List[Dict[str, Any]]: """ Extract from multiple texts with parallel preprocessing. Args: texts: List of input texts schemas: Single schema or list of schemas batch_size: Batch size for processing threshold: Confidence threshold num_workers: Workers for parallel preprocessing format_results: Format output nicely include_confidence: Include confidence scores include_spans: Include character-level start/end positions Returns: List of extraction results """ if not texts: return [] self.eval() self.processor.change_mode(is_training=False) # Normalize schemas if isinstance(schemas, list): if len(schemas) != len(texts): raise ValueError(f"Schema count ({len(schemas)}) != text count ({len(texts)})") schema_list = schemas else: schema_list = [schemas] * len(texts) # Build schema dicts and metadata schema_dicts = [] metadata_list = [] for schema in schema_list: if hasattr(schema, 'build'): schema_dict = schema.build() # Extract classification task names classification_tasks = [c["task"] for c in schema_dict.get("classifications", [])] metadata = { "field_metadata": schema._field_metadata, "entity_metadata": schema._entity_metadata, "relation_metadata": getattr(schema, '_relation_metadata', {}), "field_orders": schema._field_orders, "entity_order": schema._entity_order, "relation_order": getattr(schema, '_relation_order', []), "classification_tasks": classification_tasks } else: schema_dict = schema # Extract classification task names from dict schema classification_tasks = [c["task"] for c in schema_dict.get("classifications", [])] metadata = { "field_metadata": {}, "entity_metadata": {}, "relation_metadata": {}, "field_orders": {}, "entity_order": [], "relation_order": [], "classification_tasks": classification_tasks } # Ensure classifications have true_label for cls_config in schema_dict.get("classifications", []): cls_config.setdefault("true_label", ["N/A"]) schema_dicts.append(schema_dict) metadata_list.append(metadata) # OPT-9: Skip duplicate normalization — _collate_batch handles it dataset = list(zip(texts, schema_dicts)) # OPT-11: Reuse cached collator instance if self._inference_collator is None: self._inference_collator = ExtractorCollator(self.processor, is_training=False) collator = self._inference_collator # OPT-12: Skip DataLoader overhead for single-batch inputs if len(dataset) <= batch_size and num_workers == 0: batches = [collator(dataset)] else: batches = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collator, pin_memory=True if torch.cuda.is_available() else False, ) # Process batches all_results = [] sample_idx = 0 device = next(self.parameters()).device for batch in batches: batch = batch.to(device) batch_results = self._extract_from_batch( batch, threshold, metadata_list[sample_idx:sample_idx + len(batch)], include_confidence, include_spans ) if format_results: for i, result in enumerate(batch_results): meta = metadata_list[sample_idx + i] requested_relations = meta.get("relation_order", []) classification_tasks = meta.get("classification_tasks", []) batch_results[i] = self.format_results( result, include_confidence, requested_relations, classification_tasks ) all_results.extend(batch_results) sample_idx += len(batch) return all_results def _extract_from_batch( self, batch: PreprocessedBatch, threshold: float, metadata_list: List[Dict], include_confidence: bool, include_spans: bool ) -> List[Dict[str, Any]]: """Extract from preprocessed batch.""" # Encode batch all_token_embs, all_schema_embs = self.processor.extract_embeddings_from_batch( self.encoder( input_ids=batch.input_ids, attention_mask=batch.attention_mask ).last_hidden_state, batch.input_ids, batch ) results = [] for i in range(len(batch)): try: sample_result = self._extract_sample( token_embs=all_token_embs[i], schema_embs=all_schema_embs[i], schema_tokens_list=batch.schema_tokens_list[i], task_types=batch.task_types[i], text_tokens=batch.text_tokens[i], original_text=batch.original_texts[i], schema=batch.original_schemas[i], start_mapping=batch.start_mappings[i], end_mapping=batch.end_mappings[i], threshold=threshold, metadata=metadata_list[i], include_confidence=include_confidence, include_spans=include_spans ) results.append(sample_result) except Exception as e: print(f"Error extracting sample {i}: {e}") results.append({}) return results def _extract_sample( self, token_embs: torch.Tensor, schema_embs: List[List[torch.Tensor]], schema_tokens_list: List[List[str]], task_types: List[str], text_tokens: List[str], original_text: str, schema: Dict, start_mapping: List[int], end_mapping: List[int], threshold: float, metadata: Dict, include_confidence: bool, include_spans: bool ) -> Dict[str, Any]: """Extract from single sample.""" results = {} # Compute span representations if needed has_span_task = any(t != "classifications" for t in task_types) span_info = None if has_span_task and token_embs.numel() > 0: span_info = self.compute_span_rep(token_embs) # Build classification field map cls_fields = {} for struct in schema.get("json_structures", []): for parent, fields in struct.items(): for fname, fval in fields.items(): if isinstance(fval, dict) and "choices" in fval: cls_fields[f"{parent}.{fname}"] = fval["choices"] # OPT-3: Use start_mapping length instead of re-tokenizing text text_len = len(start_mapping) for i, (schema_tokens, task_type) in enumerate(zip(schema_tokens_list, task_types)): if len(schema_tokens) < 4 or not schema_embs[i]: continue schema_name = schema_tokens[2].split(" [DESCRIPTION] ")[0] embs = torch.stack(schema_embs[i]) if task_type == "classifications": self._extract_classification_result( results, schema_name, schema, embs, schema_tokens ) else: self._extract_span_result( results, schema_name, task_type, embs, span_info, schema_tokens, text_tokens, text_len, original_text, start_mapping, end_mapping, threshold, metadata, cls_fields, include_confidence, include_spans ) return results def _extract_classification_result( self, results: Dict, schema_name: str, schema: Dict, embs: torch.Tensor, schema_tokens: List[str] ): """Extract classification result.""" cls_config = next( c for c in schema["classifications"] if schema_tokens[2].startswith(c["task"]) ) cls_embeds = embs[1:] logits = self.classifier(cls_embeds).squeeze(-1) activation = cls_config.get("class_act", "auto") is_multi = cls_config.get("multi_label", False) if activation == "sigmoid": probs = torch.sigmoid(logits) elif activation == "softmax": probs = torch.softmax(logits, dim=-1) else: probs = torch.sigmoid(logits) if is_multi else torch.softmax(logits, dim=-1) labels = cls_config["labels"] cls_threshold = cls_config.get("cls_threshold", 0.5) if is_multi: chosen = [(labels[j], probs[j].item()) for j in range(len(labels)) if probs[j].item() >= cls_threshold] if not chosen: best = int(torch.argmax(probs).item()) chosen = [(labels[best], probs[best].item())] results[schema_name] = chosen else: best = int(torch.argmax(probs).item()) results[schema_name] = (labels[best], probs[best].item()) def _extract_span_result( self, results: Dict, schema_name: str, task_type: str, embs: torch.Tensor, span_info: Dict, schema_tokens: List[str], text_tokens: List[str], text_len: int, original_text: str, start_mapping: List[int], end_mapping: List[int], threshold: float, metadata: Dict, cls_fields: Dict, include_confidence: bool, include_spans: bool ): """Extract span-based results.""" # Get field names field_names = [] for j in range(len(schema_tokens) - 1): if schema_tokens[j] in ("[E]", "[C]", "[R]"): field_names.append(schema_tokens[j + 1]) if not field_names: results[schema_name] = [] if schema_name == "entities" else {} return # Predict count count_logits = self.count_pred(embs[0].unsqueeze(0)) pred_count = int(count_logits.argmax(dim=1).item()) if pred_count <= 0 or span_info is None: if schema_name == "entities": results[schema_name] = [] elif task_type == "relations": results[schema_name] = [] else: results[schema_name] = {} return # Get span scores struct_proj = self.count_embed(embs[1:], pred_count) span_scores = torch.sigmoid( torch.einsum("lkd,bpd->bplk", span_info["span_rep"], struct_proj) ) # Extract based on type if schema_name == "entities": results[schema_name] = self._extract_entities( field_names, span_scores, text_len, text_tokens, original_text, start_mapping, end_mapping, threshold, metadata, include_confidence, include_spans ) elif task_type == "relations": results[schema_name] = self._extract_relations( schema_name, field_names, span_scores, pred_count, text_len, text_tokens, original_text, start_mapping, end_mapping, threshold, metadata, include_confidence, include_spans ) else: results[schema_name] = self._extract_structures( schema_name, field_names, span_scores, pred_count, text_len, text_tokens, original_text, start_mapping, end_mapping, threshold, metadata, cls_fields, include_confidence, include_spans ) def _extract_entities( self, entity_names: List[str], span_scores: torch.Tensor, text_len: int, text_tokens: List[str], text: str, start_map: List[int], end_map: List[int], threshold: float, metadata: Dict, include_confidence: bool, include_spans: bool ) -> List[Dict]: """Extract entity results.""" scores = span_scores[0, :, -text_len:] entity_results = OrderedDict() for name in metadata.get("entity_order", entity_names): if name not in entity_names: continue idx = entity_names.index(name) meta = metadata.get("entity_metadata", {}).get(name, {}) meta_threshold = meta.get("threshold") ent_threshold = meta_threshold if meta_threshold is not None else threshold dtype = meta.get("dtype", "list") spans = self._find_spans( scores[idx], ent_threshold, text_len, text, start_map, end_map ) if dtype == "list": entity_results[name] = self._format_spans(spans, include_confidence, include_spans) else: if spans: text_val, conf, char_start, char_end = spans[0] if include_spans and include_confidence: entity_results[name] = { "text": text_val, "confidence": conf, "start": char_start, "end": char_end } elif include_spans: entity_results[name] = { "text": text_val, "start": char_start, "end": char_end } elif include_confidence: entity_results[name] = {"text": text_val, "confidence": conf} else: entity_results[name] = text_val else: entity_results[name] = "" if not include_spans and not include_confidence else None return [entity_results] if entity_results else [] def _extract_relations( self, rel_name: str, field_names: List[str], span_scores: torch.Tensor, count: int, text_len: int, text_tokens: List[str], text: str, start_map: List[int], end_map: List[int], threshold: float, metadata: Dict, include_confidence: bool, include_spans: bool ) -> List[Union[Tuple[str, str], Dict]]: """Extract relation results with optional confidence and position info.""" instances = [] rel_threshold = threshold if rel_name in metadata.get("relation_metadata", {}): meta_threshold = metadata["relation_metadata"][rel_name].get("threshold") rel_threshold = meta_threshold if meta_threshold is not None else threshold ordered_fields = metadata.get("field_orders", {}).get(rel_name, field_names) for inst in range(count): scores = span_scores[inst, :, -text_len:] values = [] field_data = [] # Store full data for each field for fname in ordered_fields: if fname not in field_names: continue fidx = field_names.index(fname) spans = self._find_spans( scores[fidx], rel_threshold, text_len, text, start_map, end_map ) if spans: text_val, conf, char_start, char_end = spans[0] values.append(text_val) field_data.append({ "text": text_val, "confidence": conf, "start": char_start, "end": char_end }) else: values.append(None) field_data.append(None) if len(values) == 2 and values[0] and values[1]: # Format based on flags if include_spans and include_confidence: instances.append({ "head": field_data[0], "tail": field_data[1] }) elif include_spans: instances.append({ "head": {"text": field_data[0]["text"], "start": field_data[0]["start"], "end": field_data[0]["end"]}, "tail": {"text": field_data[1]["text"], "start": field_data[1]["start"], "end": field_data[1]["end"]} }) elif include_confidence: instances.append({ "head": {"text": field_data[0]["text"], "confidence": field_data[0]["confidence"]}, "tail": {"text": field_data[1]["text"], "confidence": field_data[1]["confidence"]} }) else: # Original tuple format for backward compatibility instances.append((values[0], values[1])) return instances def _extract_structures( self, struct_name: str, field_names: List[str], span_scores: torch.Tensor, count: int, text_len: int, text_tokens: List[str], text: str, start_map: List[int], end_map: List[int], threshold: float, metadata: Dict, cls_fields: Dict, include_confidence: bool, include_spans: bool ) -> List[Dict]: """Extract structure results with optional position tracking.""" instances = [] ordered_fields = metadata.get("field_orders", {}).get(struct_name, field_names) for inst in range(count): scores = span_scores[inst, :, -text_len:] instance = OrderedDict() for fname in ordered_fields: if fname not in field_names: continue fidx = field_names.index(fname) field_key = f"{struct_name}.{fname}" meta = metadata.get("field_metadata", {}).get(field_key, {}) meta_threshold = meta.get("threshold") field_threshold = meta_threshold if meta_threshold is not None else threshold dtype = meta.get("dtype", "list") validators = meta.get("validators", []) if field_key in cls_fields: # Classification field - no span positions needed choices = cls_fields[field_key] prefix_scores = span_scores[inst, fidx, :-text_len] if dtype == "list": selected = [] seen = set() for choice in choices: if choice in seen: continue idx = self._find_choice_idx(choice, text_tokens[:-text_len]) if idx >= 0 and idx < prefix_scores.shape[0]: score = prefix_scores[idx, 0].item() if score >= field_threshold: if include_confidence: selected.append({"text": choice, "confidence": score}) else: selected.append(choice) seen.add(choice) instance[fname] = selected else: best = None best_score = -1.0 for choice in choices: idx = self._find_choice_idx(choice, text_tokens[:-text_len]) if idx >= 0 and idx < prefix_scores.shape[0]: score = prefix_scores[idx, 0].item() if score > best_score: best_score = score best = choice if best and best_score >= field_threshold: if include_confidence: instance[fname] = {"text": best, "confidence": best_score} else: instance[fname] = best else: instance[fname] = None else: # Regular span field - track positions spans = self._find_spans( scores[fidx], field_threshold, text_len, text, start_map, end_map ) if validators: spans = [s for s in spans if all(v.validate(s[0]) for v in validators)] if dtype == "list": instance[fname] = self._format_spans(spans, include_confidence, include_spans) else: if spans: text_val, conf, char_start, char_end = spans[0] if include_spans and include_confidence: instance[fname] = { "text": text_val, "confidence": conf, "start": char_start, "end": char_end } elif include_spans: instance[fname] = { "text": text_val, "start": char_start, "end": char_end } elif include_confidence: instance[fname] = {"text": text_val, "confidence": conf} else: instance[fname] = text_val else: instance[fname] = None # Only add if has content if any(v is not None and v != [] for v in instance.values()): instances.append(instance) return instances def _find_spans( self, scores: torch.Tensor, threshold: float, text_len: int, text: str, start_map: List[int], end_map: List[int] ) -> List[Tuple[str, float, int, int]]: """Find valid spans above threshold. Returns (text, confidence, char_start, char_end).""" valid = torch.where(scores >= threshold) starts, widths = valid spans = [] for start, width in zip(starts.tolist(), widths.tolist()): end = start + width + 1 if 0 <= start < text_len and end <= text_len: try: char_start = start_map[start] char_end = end_map[end - 1] text_span = text[char_start:char_end].strip() except (IndexError, KeyError): continue if text_span: conf = scores[start, width].item() spans.append((text_span, conf, char_start, char_end)) return spans def _format_spans( self, spans: List[Tuple], include_confidence: bool, include_spans: bool = False ) -> Union[List[str], List[Dict], List[Tuple]]: """Format spans with overlap removal and optional position info.""" if not spans: return [] sorted_spans = sorted(spans, key=lambda x: x[1], reverse=True) selected = [] for text, conf, start, end in sorted_spans: overlap = any(not (end <= s[2] or start >= s[3]) for s in selected) if not overlap: selected.append((text, conf, start, end)) # Format based on flags if include_spans and include_confidence: return [{"text": s[0], "confidence": s[1], "start": s[2], "end": s[3]} for s in selected] elif include_spans: return [{"text": s[0], "start": s[2], "end": s[3]} for s in selected] elif include_confidence: return [{"text": s[0], "confidence": s[1]} for s in selected] else: return [s[0] for s in selected] def _find_choice_idx(self, choice: str, tokens: List[str]) -> int: """Find index of choice in tokens.""" choice_lower = choice.lower() for i, tok in enumerate(tokens): if tok.lower() == choice_lower or choice_lower in tok.lower(): return i return -1 # ========================================================================= # Result Formatting # ========================================================================= def format_results( self, results: Dict, include_confidence: bool = False, requested_relations: List[str] = None, classification_tasks: List[str] = None ) -> Dict[str, Any]: """Format extraction results.""" formatted = {} relations = {} requested_relations = requested_relations or [] classification_tasks = classification_tasks or [] for key, value in results.items(): # Check if this is a classification task (takes priority) is_classification = key in classification_tasks # Check if this is a relation is_relation = False if not is_classification: # Check if key is in requested_relations (this takes priority) if key in requested_relations: is_relation = True # Otherwise, check the value structure elif isinstance(value, list) and len(value) > 0: # Check for tuple format: [(head, tail), ...] if isinstance(value[0], tuple) and len(value[0]) == 2: is_relation = True # Check for dict format with head/tail keys: [{"head": ..., "tail": ...}, ...] elif isinstance(value[0], dict) and "head" in value[0] and "tail" in value[0]: is_relation = True if is_classification: # This is a classification task - format and add to formatted dict directly if isinstance(value, list): # Multi-label classification if include_confidence: formatted[key] = [{"label": l, "confidence": c} for l, c in value] else: formatted[key] = [l for l, _ in value] elif isinstance(value, tuple): # Single-label classification label, conf = value formatted[key] = {"label": label, "confidence": conf} if include_confidence else label else: formatted[key] = value elif is_relation: # This is a relation - store in relations dict, not formatted # Relations should always be lists, but handle edge cases defensively if isinstance(value, list): relations[key] = value else: # Unexpected non-list value for relation - convert to empty list relations[key] = [] elif isinstance(value, list): if len(value) == 0: if key == "entities": formatted[key] = {} else: formatted[key] = value elif isinstance(value[0], dict): if key == "entities": formatted[key] = self._format_entity_dict(value[0], include_confidence) else: formatted[key] = [self._format_struct(v, include_confidence) for v in value] elif isinstance(value[0], tuple): if include_confidence: formatted[key] = [{"label": l, "confidence": c} for l, c in value] else: formatted[key] = [l for l, _ in value] else: formatted[key] = value elif isinstance(value, tuple): label, conf = value formatted[key] = {"label": label, "confidence": conf} if include_confidence else label elif isinstance(value, dict): formatted[key] = self._format_struct(value, include_confidence) else: formatted[key] = value # Add all requested relations (including empty ones) for rel in requested_relations: if rel not in relations: relations[rel] = [] # Only add relation_extraction if we have relations if relations: formatted["relation_extraction"] = relations return formatted def _format_entity_dict(self, entities: Dict, include_confidence: bool) -> Dict: formatted = {} for name, spans in entities.items(): if isinstance(spans, list): unique = [] seen = set() for span in spans: if isinstance(span, tuple): text, conf, start, end = span if text and text.lower() not in seen: seen.add(text.lower()) unique.append({"text": text, "confidence": conf} if include_confidence else text) elif isinstance(span, dict): # Handle dict format (with confidence/spans) text = span.get("text", "") if text and text.lower() not in seen: seen.add(text.lower()) unique.append(span) else: # Handle string format if span and span.lower() not in seen: seen.add(span.lower()) unique.append(span) formatted[name] = unique elif isinstance(spans, tuple): text, conf, _, _ = spans formatted[name] = {"text": text, "confidence": conf} if include_confidence and text else text else: formatted[name] = spans or None return formatted def _format_struct(self, struct: Dict, include_confidence: bool) -> Dict: formatted = {} for field, value in struct.items(): if isinstance(value, list): unique = [] seen = set() for v in value: if isinstance(v, tuple): text, conf, _, _ = v if text and text.lower() not in seen: seen.add(text.lower()) unique.append({"text": text, "confidence": conf} if include_confidence else text) elif isinstance(v, dict): # Handle dict format (with confidence/spans) text = v.get("text", "") if text and text.lower() not in seen: seen.add(text.lower()) unique.append(v) else: # Handle string format if v and v.lower() not in seen: seen.add(v.lower()) unique.append(v) formatted[field] = unique elif isinstance(value, tuple): text, conf, _, _ = value formatted[field] = {"text": text, "confidence": conf} if include_confidence and text else text elif value: formatted[field] = value else: formatted[field] = None return formatted # ========================================================================= # Convenience Methods (route through batch) # ========================================================================= def extract(self, text: str, schema, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> Dict: """Extract from single text.""" return self.batch_extract([text], schema, 1, threshold, 0, format_results, include_confidence, include_spans)[0] def extract_entities(self, text: str, entity_types, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> Dict: """Extract entities from text.""" schema = self.create_schema().entities(entity_types) return self.extract(text, schema, threshold, format_results, include_confidence, include_spans) def batch_extract_entities(self, texts: List[str], entity_types, batch_size: int = 8, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> List[Dict]: """Batch extract entities.""" schema = self.create_schema().entities(entity_types) return self.batch_extract(texts, schema, batch_size, threshold, 0, format_results, include_confidence, include_spans) def classify_text(self, text: str, tasks: Dict, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> Dict: """Classify text.""" schema = self.create_schema() for name, config in tasks.items(): if isinstance(config, dict) and "labels" in config: cfg = config.copy() labels = cfg.pop("labels") schema.classification(name, labels, **cfg) else: schema.classification(name, config) return self.extract(text, schema, threshold, format_results, include_confidence, include_spans) def batch_classify_text(self, texts: List[str], tasks: Dict, batch_size: int = 8, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> List[Dict]: """Batch classify texts.""" schema = self.create_schema() for name, config in tasks.items(): if isinstance(config, dict) and "labels" in config: cfg = config.copy() labels = cfg.pop("labels") schema.classification(name, labels, **cfg) else: schema.classification(name, config) return self.batch_extract(texts, schema, batch_size, threshold, 0, format_results, include_confidence, include_spans) def extract_json(self, text: str, structures: Dict, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> Dict: """Extract structured data.""" schema = self.create_schema() for parent, fields in structures.items(): builder = schema.structure(parent) for spec in fields: name, dtype, choices, desc = self._parse_field_spec(spec) builder.field(name, dtype=dtype, choices=choices, description=desc) return self.extract(text, schema, threshold, format_results, include_confidence, include_spans) def batch_extract_json(self, texts: List[str], structures: Dict, batch_size: int = 8, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> List[Dict]: """Batch extract structured data.""" schema = self.create_schema() for parent, fields in structures.items(): builder = schema.structure(parent) for spec in fields: name, dtype, choices, desc = self._parse_field_spec(spec) builder.field(name, dtype=dtype, choices=choices, description=desc) return self.batch_extract(texts, schema, batch_size, threshold, 0, format_results, include_confidence, include_spans) def extract_relations(self, text: str, relation_types, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> Dict: """Extract relations.""" schema = self.create_schema().relations(relation_types) return self.extract(text, schema, threshold, format_results, include_confidence, include_spans) def batch_extract_relations(self, texts: List[str], relation_types, batch_size: int = 8, threshold: float = 0.5, format_results: bool = True, include_confidence: bool = False, include_spans: bool = False) -> List[Dict]: """Batch extract relations.""" schema = self.create_schema().relations(relation_types) return self.batch_extract(texts, schema, batch_size, threshold, 0, format_results, include_confidence, include_spans) def _parse_field_spec(self, spec: Union[str, Dict]) -> Tuple[str, str, Optional[List[str]], Optional[str]]: """Parse field specification string or dictionary. Format: "name::dtype::choices::description" where all parts after name are optional. - dtype: 'str' for single value, 'list' for multiple values - choices: [option1|option2|...] for enumerated options - description: free text description Examples: "restaurant::str::Restaurant name" "seating::[indoor|outdoor|bar]::Seating preference" "dietary::[vegetarian|vegan|gluten-free|none]::list::Dietary restrictions" """ if isinstance(spec, dict): return ( spec.get("name", ""), spec.get("dtype", "list"), spec.get("choices"), spec.get("description") ) parts = spec.split('::') name = parts[0] dtype, choices, desc = "list", None, None dtype_explicitly_set = False if len(parts) == 1: return name, dtype, choices, desc for part in parts[1:]: if part in ['str', 'list']: dtype = part dtype_explicitly_set = True elif part.startswith('[') and part.endswith(']'): choices = [c.strip() for c in part[1:-1].split('|')] # Only default to "str" if dtype wasn't explicitly set if not dtype_explicitly_set: dtype = "str" else: desc = part return name, dtype, choices, desc # Aliases BuilderExtractor = GLiNER2 SchemaBuilder = Schema JsonStructBuilder = StructureBuilder