""" GLiNER2 Training Data Creation & Validation Module This module provides intuitive classes for creating, validating, and managing training data for GLiNER2 models. Quick Examples -------------- Create entity examples: >>> example = InputExample( ... text="John works at Google in NYC.", ... entities={"person": ["John"], "company": ["Google"], "location": ["NYC"]} ... ) Create classification examples: >>> example = InputExample( ... text="This movie is amazing!", ... classifications=[ ... Classification(task="sentiment", labels=["positive", "negative"], true_label="positive") ... ] ... ) Create structured data examples: >>> example = InputExample( ... text="iPhone 15 costs $999", ... structures=[ ... Structure("product", name="iPhone 15", price="$999") ... ] ... ) Create relation examples: >>> example = InputExample( ... text="Elon Musk founded SpaceX.", ... relations=[ ... Relation("founded", head="Elon Musk", tail="SpaceX") ... ] ... ) Build and validate dataset: >>> dataset = TrainingDataset(examples) >>> dataset.validate() # Raises ValidationError if invalid >>> dataset.save("train.jsonl") Load from JSONL: >>> dataset = TrainingDataset.load("train.jsonl") >>> # Or load multiple files >>> dataset = TrainingDataset.load(["train1.jsonl", "train2.jsonl"]) """ from __future__ import annotations import json import random from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Union, Tuple, Iterator, TYPE_CHECKING from collections import Counter from tqdm import tqdm if TYPE_CHECKING: # Forward declarations for type checking only pass class ValidationError(Exception): """Raised when training data validation fails.""" def __init__(self, message: str, errors: List[str] = None): super().__init__(message) self.errors = errors or [] def __str__(self): if self.errors: error_list = "\n - ".join(self.errors[:10]) suffix = f"\n ... and {len(self.errors) - 10} more errors" if len(self.errors) > 10 else "" return f"{self.args[0]}\n - {error_list}{suffix}" return self.args[0] # ============================================================================= # Data Format Detection & Loading # ============================================================================= class DataFormat: """Enum-like class for supported data formats.""" JSONL = "jsonl" JSONL_LIST = "jsonl_list" INPUT_EXAMPLE_LIST = "input_example_list" TRAINING_DATASET = "training_dataset" DICT_LIST = "dict_list" EXTRACTOR_DATASET = "extractor_dataset" def detect_data_format(data: Any) -> str: """ Detect the format of input data. Parameters ---------- data : Any Input data in any supported format. Returns ------- str The detected format name from DataFormat. """ # String path if isinstance(data, str): return DataFormat.JSONL # Path object if isinstance(data, Path): return DataFormat.JSONL # List types if isinstance(data, list) and len(data) > 0: first = data[0] if isinstance(first, (str, Path)): return DataFormat.JSONL_LIST if isinstance(first, InputExample): return DataFormat.INPUT_EXAMPLE_LIST if isinstance(first, dict): return DataFormat.DICT_LIST # Empty list - default to dict list if isinstance(data, list) and len(data) == 0: return DataFormat.DICT_LIST # TrainingDataset if isinstance(data, TrainingDataset): return DataFormat.TRAINING_DATASET # ExtractorDataset (internal) - forward reference if type(data).__name__ == 'ExtractorDataset': return DataFormat.EXTRACTOR_DATASET raise ValueError(f"Unsupported data format: {type(data)}") class DataLoader_Factory: """ Factory for loading data from various formats into a unified internal format. All loaders convert data to List[Dict] format where each dict has: - "input": str (the text) - "output": Dict (the schema/annotations) Or alternatively: - "text": str - "schema": Dict """ @staticmethod def load( data: Any, max_samples: int = -1, shuffle: bool = True, seed: int = 42, validate: bool = False, ) -> List[Dict[str, Any]]: """ Load data from any supported format. Parameters ---------- data : Any Input data in any supported format. max_samples : int, default=-1 Maximum samples to load (-1 = all). shuffle : bool, default=True Whether to shuffle the data. seed : int, default=42 Random seed for shuffling. validate : bool, default=False Whether to validate the data. Validation is always strict: checks that entity spans, relation values, and structure field values exist in the text. Returns ------- List[Dict[str, Any]] List of records in unified format. """ fmt = detect_data_format(data) # Load based on format if fmt == DataFormat.JSONL: records = DataLoader_Factory._load_jsonl(data) elif fmt == DataFormat.JSONL_LIST: records = DataLoader_Factory._load_jsonl_list(data) elif fmt == DataFormat.INPUT_EXAMPLE_LIST: records = DataLoader_Factory._load_input_examples(data) elif fmt == DataFormat.TRAINING_DATASET: records = DataLoader_Factory._load_training_dataset(data) elif fmt == DataFormat.DICT_LIST: records = DataLoader_Factory._load_dict_list(data) elif fmt == DataFormat.EXTRACTOR_DATASET: records = data.data.copy() else: raise ValueError(f"Unsupported data format: {type(data)}") # Validate if requested if validate and records: valid_indices, invalid_info = DataLoader_Factory._validate_records(records) if invalid_info: total_records = len(records) num_invalid = len(invalid_info) num_valid = len(valid_indices) print(f"\nValidation: Found {num_invalid} invalid record(s) out of {total_records} total") print("Removed invalid records:") # Print first 5 invalid records for idx, (record_idx, record, errors) in enumerate(invalid_info[:5]): # Print first error for this record error_msg = errors[0] if errors else "Unknown error" print(f" Record {record_idx}: {error_msg}") if num_invalid > 5: print(f" ... and {num_invalid - 5} more invalid record(s)") print(f"Kept {num_valid} valid record(s)\n") # Filter records to keep only valid ones records = [records[i] for i in valid_indices] # Shuffle if shuffle and records: random.seed(seed) random.shuffle(records) # Limit samples if max_samples > 0 and len(records) > max_samples: records = records[:max_samples] return records @staticmethod def _load_jsonl(path: Union[str, Path]) -> List[Dict]: """Load from single JSONL file.""" path = Path(path) if not path.exists(): raise FileNotFoundError(f"File not found: {path}") records = [] with open(path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f, 1): line = line.strip() if line: try: records.append(json.loads(line)) except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON in {path} line {line_num}: {e}") return records @staticmethod def _load_jsonl_list(paths: List[Union[str, Path]]) -> List[Dict]: """Load from multiple JSONL files.""" records = [] for path in paths: records.extend(DataLoader_Factory._load_jsonl(path)) return records @staticmethod def _load_input_examples(examples: List[InputExample]) -> List[Dict]: """Load from list of InputExample objects.""" return [ex.to_dict() for ex in examples] @staticmethod def _load_training_dataset(dataset: TrainingDataset) -> List[Dict]: """Load from TrainingDataset object.""" return dataset.to_records() @staticmethod def _load_dict_list(dicts: List[Dict]) -> List[Dict]: """Load from list of dicts.""" if not dicts: return [] first = dicts[0] # Check format if "input" in first and "output" in first: # Already in correct format return dicts elif "text" in first and "schema" in first: # Alternative format - keep as is (handled in __getitem__) return dicts elif "text" in first: # Maybe has entities/classifications at top level - try to convert records = [] for d in dicts: output = {} if "entities" in d: output["entities"] = d["entities"] if "classifications" in d: output["classifications"] = d["classifications"] if "relations" in d: output["relations"] = d["relations"] if "json_structures" in d: output["json_structures"] = d["json_structures"] records.append({"input": d["text"], "output": output}) return records else: raise ValueError( f"Unknown dict format. Expected keys like 'input'/'output', 'text'/'schema', " f"or 'text' with annotation keys. Got: {list(first.keys())}" ) @staticmethod def _validate_records(records: List[Dict]) -> Tuple[List[int], List[Tuple[int, Dict, List[str]]]]: """ Validate and sanitize all records, removing only invalid parts. Uses granular sanitization: - Entities: Drop entity type if any mention not found - Classifications: Drop individual invalid classifications - Structures: Remove invalid fields, drop if all invalid - Relations: Drop relation if any field invalid - Record: Drop only if no valid tasks remain Returns ------- Tuple[List[int], List[Tuple[int, Dict, List[str]]]] - First element: List of valid record indices (sanitized records replace originals) - Second element: List of (index, original_record, warning_messages) for dropped records """ valid_indices = [] invalid_info = [] for i, record in tqdm(enumerate(records), total=len(records), desc="Validating records", unit="record"): warnings = [] try: example = InputExample.from_dict(record) sanitize_warnings, is_valid = example.sanitize() if is_valid: # Replace record with sanitized version records[i] = example.to_dict() valid_indices.append(i) if sanitize_warnings: # Record was sanitized but still valid warnings.extend(sanitize_warnings) else: # No valid content remains warnings.extend(sanitize_warnings) invalid_info.append((i, record, warnings)) except Exception as e: warnings.append(f"Failed to parse - {e}") invalid_info.append((i, record, warnings)) return valid_indices, invalid_info # Type alias for flexible data input TrainDataInput = Union[ str, # Single JSONL path Path, # Single JSONL path List[str], # Multiple JSONL paths List[Path], # Multiple JSONL paths List[Dict[str, Any]], # Raw records 'TrainingDataset', # TrainingDataset (forward reference) 'List[InputExample]', # List of InputExample (forward reference) 'ExtractorDataset', # Legacy dataset (forward reference) ] # ============================================================================= # Training Data Classes # ============================================================================= @dataclass class Classification: """ A classification task definition. Parameters ---------- task : str Name of the classification task (e.g., "sentiment", "category"). labels : List[str] All possible labels for this task. true_label : str or List[str] The correct label(s) for this example. multi_label : bool, default=False Whether multiple labels can be selected. prompt : str, optional Custom prompt for the task. examples : List[Tuple[str, str]], optional Few-shot examples as (input, output) pairs. label_descriptions : Dict[str, str], optional Descriptions for each label. """ task: str labels: List[str] true_label: Union[str, List[str]] multi_label: bool = False prompt: Optional[str] = None examples: Optional[List[Tuple[str, str]]] = None label_descriptions: Optional[Dict[str, str]] = None def __post_init__(self): if isinstance(self.true_label, str): self._true_label_list = [self.true_label] else: self._true_label_list = list(self.true_label) # Auto-infer multi_label=True when multiple true labels are provided if len(self._true_label_list) > 1: self.multi_label = True def validate(self) -> List[str]: """Validate this classification and return list of errors.""" errors = [] if not self.task: errors.append("Classification task name cannot be empty") if not self.labels: errors.append(f"Classification '{self.task}' has no labels") for label in self._true_label_list: if label not in self.labels: errors.append(f"True label '{label}' not in labels list for task '{self.task}'") if len(self._true_label_list) > 1 and not self.multi_label: errors.append(f"Multiple true labels provided for '{self.task}' but multi_label=False") if self.label_descriptions: for key in self.label_descriptions: if key not in self.labels: errors.append(f"Label description key '{key}' not in labels for task '{self.task}'") if self.examples: for i, ex in enumerate(self.examples): if not isinstance(ex, (list, tuple)) or len(ex) != 2: errors.append(f"Example {i} for task '{self.task}' must be (input, output) pair") return errors def to_dict(self) -> Dict[str, Any]: """Convert to training format dictionary.""" result = {"task": self.task, "labels": self.labels, "true_label": self._true_label_list} if self.multi_label: result["multi_label"] = True if self.prompt: result["prompt"] = self.prompt if self.examples: result["examples"] = [list(ex) for ex in self.examples] if self.label_descriptions: result["label_descriptions"] = self.label_descriptions return result @dataclass class ChoiceField: """ A field with predefined choices (classification within structure). Parameters ---------- value : str The selected value. choices : List[str] All possible choices. """ value: str choices: List[str] def validate(self, field_name: str) -> List[str]: errors = [] if self.value not in self.choices: errors.append(f"Choice value '{self.value}' not in choices {self.choices} for field '{field_name}'") return errors def to_dict(self) -> Dict[str, Any]: return {"value": self.value, "choices": self.choices} @dataclass class Structure: """ A structured data extraction definition. Parameters ---------- struct_name : str Name of the structure (e.g., "product", "contact"). **fields : Any Field names and values. Values can be: - str: Single string value - List[str]: Multiple values - ChoiceField: Classification-style field with choices Examples -------- >>> struct = Structure("product", name="iPhone", price="$999") >>> struct = Structure("contact", name="John", email="john@example.com") """ struct_name: str _fields: Dict[str, Any] = field(default_factory=dict) descriptions: Optional[Dict[str, str]] = None def __init__(self, struct_name: str, _descriptions: Dict[str, str] = None, **fields): self.struct_name = struct_name self._fields = fields self.descriptions = _descriptions def validate(self, text: str) -> List[str]: """ Validate this structure. Parameters ---------- text : str The text to validate against. Field values must exist in this text. Returns ------- List[str] List of validation errors. """ errors = [] if not self.struct_name: errors.append("Structure name cannot be empty") if not self._fields: errors.append(f"Structure '{self.struct_name}' has no fields") for field_name, value in self._fields.items(): if isinstance(value, ChoiceField): errors.extend(value.validate(f"{self.struct_name}.{field_name}")) elif isinstance(value, list): for i, v in enumerate(value): if v and v.lower() not in text.lower(): errors.append(f"List value '{v}' at index {i} in '{self.struct_name}.{field_name}' not found in text") elif isinstance(value, str): if value and value.lower() not in text.lower(): errors.append(f"Value '{value}' for '{self.struct_name}.{field_name}' not found in text") return errors def to_dict(self) -> Dict[str, Dict[str, Any]]: fields_dict = {} for field_name, value in self._fields.items(): if isinstance(value, ChoiceField): fields_dict[field_name] = value.to_dict() else: fields_dict[field_name] = value return {self.struct_name: fields_dict} def get_descriptions(self) -> Optional[Dict[str, Dict[str, str]]]: if self.descriptions: return {self.struct_name: self.descriptions} return None @dataclass class Relation: """ A relation extraction definition. Parameters ---------- name : str Name of the relation (e.g., "works_for", "founded"). head : str, optional The source/subject entity. tail : str, optional The target/object entity. **fields : Any Custom field names and values (use instead of head/tail). """ name: str head: Optional[str] = None tail: Optional[str] = None _fields: Dict[str, str] = field(default_factory=dict) def __init__(self, name: str, head: str = None, tail: str = None, **fields): self.name = name self.head = head self.tail = tail if fields: self._fields = fields elif head is not None and tail is not None: self._fields = {"head": head, "tail": tail} else: self._fields = {} if head is not None: self._fields["head"] = head if tail is not None: self._fields["tail"] = tail def validate(self, text: str) -> List[str]: """ Validate this relation. Parameters ---------- text : str The text to validate against. Field values must exist in this text. Returns ------- List[str] List of validation errors. """ errors = [] if not self.name: errors.append("Relation name cannot be empty") if not self._fields: errors.append(f"Relation '{self.name}' has no fields") for field_name, value in self._fields.items(): if isinstance(value, str) and value: if value.lower() not in text.lower(): errors.append(f"Relation value '{value}' for '{self.name}.{field_name}' not found in text") return errors def get_field_names(self) -> List[str]: return list(self._fields.keys()) def to_dict(self) -> Dict[str, Dict[str, str]]: return {self.name: self._fields} @dataclass class InputExample: """ A single training example for GLiNER2. Parameters ---------- text : str The input text for this example. entities : Dict[str, List[str]], optional Entity type to mentions mapping. entity_descriptions : Dict[str, str], optional Descriptions for entity types. classifications : List[Classification], optional Classification tasks for this example. structures : List[Structure], optional Structured data extractions for this example. relations : List[Relation], optional Relation extractions for this example. Examples -------- >>> example = InputExample( ... text="John Smith works at Google.", ... entities={"person": ["John Smith"], "company": ["Google"]} ... ) """ text: str entities: Optional[Dict[str, List[str]]] = None entity_descriptions: Optional[Dict[str, str]] = None classifications: Optional[List[Classification]] = None structures: Optional[List[Structure]] = None relations: Optional[List[Relation]] = None def __post_init__(self): if self.entities is None: self.entities = {} if self.classifications is None: self.classifications = [] if self.structures is None: self.structures = [] if self.relations is None: self.relations = [] def validate(self) -> List[str]: """ Validate this example. Validation is always strict: checks that entity mentions, relation values, and structure field values exist in the text (case-insensitive). Returns ------- List[str] List of validation errors. Empty list means valid. """ errors = [] if not self.text or not self.text.strip(): errors.append("Text cannot be empty") return errors if self.entities: for entity_type, mentions in self.entities.items(): if not entity_type: errors.append("Entity type cannot be empty") for mention in mentions: if mention and mention.lower() not in self.text.lower(): errors.append(f"Entity '{mention}' (type: {entity_type}) not found in text") if self.entity_descriptions and self.entities: for desc_type in self.entity_descriptions: if desc_type not in self.entities: errors.append(f"Entity description for '{desc_type}' but no entities of that type") for cls in self.classifications: errors.extend(cls.validate()) for struct in self.structures: errors.extend(struct.validate(self.text)) relation_fields = {} for rel in self.relations: errors.extend(rel.validate(self.text)) field_names = tuple(sorted(rel.get_field_names())) if rel.name in relation_fields: if relation_fields[rel.name] != field_names: errors.append(f"Relation '{rel.name}' has inconsistent fields: {relation_fields[rel.name]} vs {field_names}") else: relation_fields[rel.name] = field_names has_content = bool(self.entities) or bool(self.classifications) or bool(self.structures) or bool(self.relations) if not has_content: errors.append("Example must have at least one task (entities, classifications, structures, or relations)") return errors def is_valid(self) -> bool: """Check if this example is valid.""" return len(self.validate()) == 0 def sanitize(self) -> Tuple[List[str], bool]: """ Remove invalid parts from this example, keeping only valid content. Mutates self in-place. Granular removal strategy: - Entities: Drop entire entity type if ANY mention is not found in text - Classifications: Drop individual classifications that have errors - Structures: Remove invalid fields; drop structure only if ALL fields become invalid - Relations: Drop the specific relation if ANY field has an error - Example: Mark as invalid only if no valid tasks remain Returns ------- Tuple[List[str], bool] - List of warning messages about what was removed - bool: True if example still has valid content, False if should be dropped """ warnings = [] if not self.text or not self.text.strip(): warnings.append("Text is empty") return warnings, False # 1. Sanitize entities - drop entity type if any mention not found if self.entities: types_to_remove = [] for entity_type, mentions in self.entities.items(): if not entity_type: types_to_remove.append(entity_type) warnings.append(f"Entity type is empty") continue # Check if any mention is not in text has_invalid = False for mention in mentions: if mention and mention.lower() not in self.text.lower(): has_invalid = True warnings.append(f"Entity '{mention}' (type: {entity_type}) not found in text - dropping entity type") break if has_invalid: types_to_remove.append(entity_type) # Remove invalid entity types for entity_type in types_to_remove: del self.entities[entity_type] # Clean up entity descriptions for removed types if self.entity_descriptions: desc_to_remove = [desc_type for desc_type in self.entity_descriptions if desc_type not in self.entities] for desc_type in desc_to_remove: del self.entity_descriptions[desc_type] # 2. Sanitize classifications - drop individual invalid ones if self.classifications: valid_classifications = [] for cls in self.classifications: cls_errors = cls.validate() if cls_errors: warnings.append(f"Classification '{cls.task}' has errors - dropping: {cls_errors[0]}") else: valid_classifications.append(cls) self.classifications = valid_classifications # 3. Sanitize structures - remove invalid fields, drop if all invalid if self.structures: valid_structures = [] for struct in self.structures: if not struct.struct_name: warnings.append(f"Structure has empty name - dropping") continue if not struct._fields: warnings.append(f"Structure '{struct.struct_name}' has no fields - dropping") continue # Filter out invalid fields valid_fields = {} for field_name, value in struct._fields.items(): is_valid = True if isinstance(value, ChoiceField): field_errors = value.validate(f"{struct.struct_name}.{field_name}") if field_errors: warnings.append(f"Field '{struct.struct_name}.{field_name}' invalid - dropping field") is_valid = False elif isinstance(value, list): for v in value: if v and v.lower() not in self.text.lower(): warnings.append(f"List value '{v}' in '{struct.struct_name}.{field_name}' not found - dropping field") is_valid = False break elif isinstance(value, str): if value and value.lower() not in self.text.lower(): warnings.append(f"Value '{value}' for '{struct.struct_name}.{field_name}' not found - dropping field") is_valid = False if is_valid: valid_fields[field_name] = value # Only keep structure if it has at least one valid field if valid_fields: struct._fields = valid_fields valid_structures.append(struct) else: warnings.append(f"Structure '{struct.struct_name}' has no valid fields - dropping") self.structures = valid_structures # 4. Sanitize relations - drop entire relation if any field is invalid if self.relations: valid_relations = [] for rel in self.relations: if not rel.name: warnings.append(f"Relation has empty name - dropping") continue if not rel._fields: warnings.append(f"Relation '{rel.name}' has no fields - dropping") continue # Check if any field value is invalid has_invalid = False for field_name, value in rel._fields.items(): if isinstance(value, str) and value: if value.lower() not in self.text.lower(): warnings.append(f"Relation '{rel.name}' field '{field_name}' value '{value}' not found - dropping relation") has_invalid = True break if not has_invalid: valid_relations.append(rel) self.relations = valid_relations # Check if example still has any valid content has_content = bool(self.entities) or bool(self.classifications) or bool(self.structures) or bool(self.relations) if not has_content: warnings.append("No valid tasks remain after sanitization") return warnings, False return warnings, True def to_dict(self) -> Dict[str, Any]: """Convert to GLiNER2 training format.""" output = {} if self.entities: output["entities"] = self.entities if self.entity_descriptions: output["entity_descriptions"] = self.entity_descriptions if self.classifications: output["classifications"] = [cls.to_dict() for cls in self.classifications] if self.structures: output["json_structures"] = [struct.to_dict() for struct in self.structures] all_descriptions = {} for struct in self.structures: desc = struct.get_descriptions() if desc: all_descriptions.update(desc) if all_descriptions: output["json_descriptions"] = all_descriptions if self.relations: output["relations"] = [rel.to_dict() for rel in self.relations] return {"input": self.text, "output": output} def to_json(self) -> str: return json.dumps(self.to_dict(), ensure_ascii=False) @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'InputExample': """Create InputExample from training format dictionary.""" text = data["input"] output = data["output"] entities = output.get("entities") entity_descriptions = output.get("entity_descriptions") classifications = [] for cls_data in output.get("classifications", []): classifications.append(Classification( task=cls_data["task"], labels=cls_data["labels"], true_label=cls_data["true_label"], multi_label=cls_data.get("multi_label", False), prompt=cls_data.get("prompt"), examples=[tuple(ex) for ex in cls_data.get("examples", [])] or None, label_descriptions=cls_data.get("label_descriptions") )) structures = [] json_descriptions = output.get("json_descriptions", {}) for struct_data in output.get("json_structures", []): for struct_name, fields in struct_data.items(): parsed_fields = {} for field_name, value in fields.items(): if isinstance(value, dict) and "value" in value and "choices" in value: parsed_fields[field_name] = ChoiceField(value["value"], value["choices"]) else: parsed_fields[field_name] = value structures.append(Structure(struct_name, _descriptions=json_descriptions.get(struct_name), **parsed_fields)) relations = [] for rel_data in output.get("relations", []): for rel_name, fields in rel_data.items(): if "head" in fields and "tail" in fields and len(fields) == 2: relations.append(Relation(rel_name, head=fields["head"], tail=fields["tail"])) else: relations.append(Relation(rel_name, **fields)) return cls( text=text, entities=entities, entity_descriptions=entity_descriptions, classifications=classifications if classifications else None, structures=structures if structures else None, relations=relations if relations else None ) @classmethod def from_json(cls, json_str: str) -> 'InputExample': return cls.from_dict(json.loads(json_str)) class TrainingDataset: """ A collection of InputExamples for training GLiNER2. Can be created from: - List of InputExample objects - JSONL file path(s) - Raw dict data Parameters ---------- examples : List[InputExample], optional Initial list of examples. Examples -------- >>> # From InputExample list >>> dataset = TrainingDataset([example1, example2]) >>> >>> # From JSONL file >>> dataset = TrainingDataset.load("train.jsonl") >>> >>> # From multiple JSONL files >>> dataset = TrainingDataset.load(["train1.jsonl", "train2.jsonl"]) """ def __init__(self, examples: List[InputExample] = None): self.examples: List[InputExample] = examples or [] def __len__(self) -> int: return len(self.examples) def __getitem__(self, idx: int) -> InputExample: return self.examples[idx] def __iter__(self) -> Iterator[InputExample]: return iter(self.examples) def add(self, example: InputExample) -> 'TrainingDataset': self.examples.append(example) return self def add_many(self, examples: List[InputExample]) -> 'TrainingDataset': self.examples.extend(examples) return self def validate(self, raise_on_error: bool = True) -> Dict[str, Any]: """ Validate all examples in the dataset. Validation is always strict: checks that entity mentions, relation values, and structure field values exist in the text (case-insensitive). Parameters ---------- raise_on_error : bool, default=True If True, raises ValidationError when invalid examples are found. If False, returns validation report without raising. Returns ------- Dict[str, Any] Validation report with counts and error details. """ all_errors = [] valid_count = 0 invalid_indices = [] for i, example in enumerate(self.examples): errors = example.validate() if errors: invalid_indices.append(i) for error in errors: all_errors.append(f"Example {i}: {error}") else: valid_count += 1 report = { "valid": valid_count, "invalid": len(invalid_indices), "total": len(self.examples), "invalid_indices": invalid_indices, "errors": all_errors } if all_errors and raise_on_error: raise ValidationError(f"Dataset validation failed: {len(invalid_indices)} invalid examples", all_errors) return report def validate_relation_consistency(self) -> List[str]: """Validate that relation field structures are consistent across the dataset.""" errors = [] relation_fields: Dict[str, Tuple[int, Tuple[str, ...]]] = {} for i, example in enumerate(self.examples): for rel in example.relations: field_names = tuple(sorted(rel.get_field_names())) if rel.name in relation_fields: first_idx, first_fields = relation_fields[rel.name] if first_fields != field_names: errors.append(f"Relation '{rel.name}' field inconsistency: Example {first_idx} has {list(first_fields)}, but Example {i} has {list(field_names)}") else: relation_fields[rel.name] = (i, field_names) return errors def stats(self) -> Dict[str, Any]: """Get statistics about the dataset.""" stats = { "total_examples": len(self.examples), "entity_types": Counter(), "entity_mentions": 0, "classification_tasks": Counter(), "classification_labels": {}, "structure_types": Counter(), "relation_types": Counter(), "text_lengths": [], "task_distribution": { "entities_only": 0, "classifications_only": 0, "structures_only": 0, "relations_only": 0, "multi_task": 0, "empty": 0 } } for example in self.examples: stats["text_lengths"].append(len(example.text)) for entity_type, mentions in example.entities.items(): stats["entity_types"][entity_type] += len(mentions) stats["entity_mentions"] += len(mentions) for cls in example.classifications: stats["classification_tasks"][cls.task] += 1 if cls.task not in stats["classification_labels"]: stats["classification_labels"][cls.task] = Counter() for label in cls._true_label_list: stats["classification_labels"][cls.task][label] += 1 for struct in example.structures: stats["structure_types"][struct.struct_name] += 1 for rel in example.relations: stats["relation_types"][rel.name] += 1 has_entities = bool(example.entities) has_cls = bool(example.classifications) has_struct = bool(example.structures) has_rel = bool(example.relations) task_count = sum([has_entities, has_cls, has_struct, has_rel]) if task_count == 0: stats["task_distribution"]["empty"] += 1 elif task_count > 1: stats["task_distribution"]["multi_task"] += 1 elif has_entities: stats["task_distribution"]["entities_only"] += 1 elif has_cls: stats["task_distribution"]["classifications_only"] += 1 elif has_struct: stats["task_distribution"]["structures_only"] += 1 elif has_rel: stats["task_distribution"]["relations_only"] += 1 if stats["text_lengths"]: lengths = stats["text_lengths"] stats["text_length_stats"] = { "min": min(lengths), "max": max(lengths), "mean": sum(lengths) / len(lengths), "median": sorted(lengths)[len(lengths) // 2] } stats["entity_types"] = dict(stats["entity_types"]) stats["classification_tasks"] = dict(stats["classification_tasks"]) stats["classification_labels"] = {k: dict(v) for k, v in stats["classification_labels"].items()} stats["structure_types"] = dict(stats["structure_types"]) stats["relation_types"] = dict(stats["relation_types"]) return stats def print_stats(self): """Print formatted statistics.""" s = self.stats() print(f"\n{'='*60}") print(f"GLiNER2 Training Dataset Statistics") print(f"{'='*60}") print(f"Total examples: {s['total_examples']}") if s.get('text_length_stats'): tls = s['text_length_stats'] print(f"\nText lengths: min={tls['min']}, max={tls['max']}, mean={tls['mean']:.1f}") print(f"\nTask Distribution:") for task, count in s['task_distribution'].items(): if count > 0: print(f" {task}: {count} ({100*count/s['total_examples']:.1f}%)") if s['entity_types']: print(f"\nEntity Types ({s['entity_mentions']} total mentions):") for etype, count in sorted(s['entity_types'].items(), key=lambda x: -x[1]): print(f" {etype}: {count}") if s['classification_tasks']: print(f"\nClassification Tasks:") for task, count in s['classification_tasks'].items(): print(f" {task}: {count} examples") if task in s['classification_labels']: for label, lcount in s['classification_labels'][task].items(): print(f" - {label}: {lcount}") if s['structure_types']: print(f"\nStructure Types:") for stype, count in s['structure_types'].items(): print(f" {stype}: {count}") if s['relation_types']: print(f"\nRelation Types:") for rtype, count in s['relation_types'].items(): print(f" {rtype}: {count}") print(f"{'='*60}\n") def to_jsonl(self) -> str: return "\n".join(example.to_json() for example in self.examples) def to_records(self) -> List[Dict[str, Any]]: """Convert to list of record dicts for trainer.""" return [ex.to_dict() for ex in self.examples] def save(self, path: Union[str, Path], validate_first: bool = True): """Save dataset to JSONL file.""" if validate_first: self.validate() path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) with open(path, 'w', encoding='utf-8') as f: for example in self.examples: f.write(example.to_json() + '\n') print(f"Saved {len(self.examples)} examples to {path}") @classmethod def load(cls, paths: Union[str, Path, List[Union[str, Path]]], shuffle: bool = False, seed: int = 42) -> 'TrainingDataset': """ Load dataset from JSONL file(s). Parameters ---------- paths : str, Path, or List Single file path or list of file paths. shuffle : bool, default=False Whether to shuffle the loaded examples. seed : int, default=42 Random seed for shuffling. Returns ------- TrainingDataset """ if isinstance(paths, (str, Path)): paths = [paths] examples = [] for path in paths: path = Path(path) with open(path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f, 1): line = line.strip() if line: try: data = json.loads(line) examples.append(InputExample.from_dict(data)) except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON in {path} line {line_num}: {e}") except Exception as e: raise ValueError(f"Error parsing {path} line {line_num}: {e}") print(f"Loaded {len(examples)} examples from {path}") if shuffle: random.seed(seed) random.shuffle(examples) return cls(examples) @classmethod def from_records(cls, records: List[Dict[str, Any]]) -> 'TrainingDataset': """Create dataset from list of record dicts.""" examples = [InputExample.from_dict(r) for r in records] return cls(examples) def split(self, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, shuffle: bool = True, seed: int = 42) -> Tuple['TrainingDataset', 'TrainingDataset', 'TrainingDataset']: """Split dataset into train/val/test sets.""" if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: raise ValueError("Ratios must sum to 1.0") indices = list(range(len(self.examples))) if shuffle: random.seed(seed) random.shuffle(indices) n = len(indices) train_end = int(n * train_ratio) val_end = train_end + int(n * val_ratio) return ( TrainingDataset([self.examples[i] for i in indices[:train_end]]), TrainingDataset([self.examples[i] for i in indices[train_end:val_end]]), TrainingDataset([self.examples[i] for i in indices[val_end:]]) ) def filter(self, predicate) -> 'TrainingDataset': """Filter examples based on a predicate function.""" return TrainingDataset([ex for ex in self.examples if predicate(ex)]) def sample(self, n: int, seed: int = 42) -> 'TrainingDataset': """Random sample of examples.""" random.seed(seed) return TrainingDataset(random.sample(self.examples, min(n, len(self.examples)))) # Convenience functions def create_entity_example(text: str, entities: Dict[str, List[str]], descriptions: Dict[str, str] = None) -> InputExample: """Create an entity extraction example.""" return InputExample(text=text, entities=entities, entity_descriptions=descriptions) def create_classification_example(text: str, task: str, labels: List[str], true_label: Union[str, List[str]], multi_label: bool = False, **kwargs) -> InputExample: """Create a classification example.""" return InputExample(text=text, classifications=[Classification(task=task, labels=labels, true_label=true_label, multi_label=multi_label, **kwargs)]) def create_structure_example(text: str, structure_name: str, **fields) -> InputExample: """Create a structured data example.""" return InputExample(text=text, structures=[Structure(structure_name, **fields)]) def create_relation_example(text: str, relation_name: str, head: str = None, tail: str = None, **fields) -> InputExample: """Create a relation extraction example.""" return InputExample(text=text, relations=[Relation(relation_name, head=head, tail=tail, **fields)])