agent-smith/packages/GLiNER2/gliner2/training/data.py
2026-03-06 12:59:32 +01:00

1277 lines
48 KiB
Python

"""
GLiNER2 Training Data Creation & Validation Module
This module provides intuitive classes for creating, validating, and managing
training data for GLiNER2 models.
Quick Examples
--------------
Create entity examples:
>>> example = InputExample(
... text="John works at Google in NYC.",
... entities={"person": ["John"], "company": ["Google"], "location": ["NYC"]}
... )
Create classification examples:
>>> example = InputExample(
... text="This movie is amazing!",
... classifications=[
... Classification(task="sentiment", labels=["positive", "negative"], true_label="positive")
... ]
... )
Create structured data examples:
>>> example = InputExample(
... text="iPhone 15 costs $999",
... structures=[
... Structure("product", name="iPhone 15", price="$999")
... ]
... )
Create relation examples:
>>> example = InputExample(
... text="Elon Musk founded SpaceX.",
... relations=[
... Relation("founded", head="Elon Musk", tail="SpaceX")
... ]
... )
Build and validate dataset:
>>> dataset = TrainingDataset(examples)
>>> dataset.validate() # Raises ValidationError if invalid
>>> dataset.save("train.jsonl")
Load from JSONL:
>>> dataset = TrainingDataset.load("train.jsonl")
>>> # Or load multiple files
>>> dataset = TrainingDataset.load(["train1.jsonl", "train2.jsonl"])
"""
from __future__ import annotations
import json
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, Tuple, Iterator, TYPE_CHECKING
from collections import Counter
from tqdm import tqdm
if TYPE_CHECKING:
# Forward declarations for type checking only
pass
class ValidationError(Exception):
    """Raised when training data validation fails.

    Parameters
    ----------
    message : str
        Human-readable summary of the failure.
    errors : Optional[List[str]]
        Individual per-example error messages. Only the first 10 are
        rendered by ``__str__`` to keep output readable; the rest are
        summarized as a count.
    """

    # Fix: the original annotated the parameter as ``List[str] = None``,
    # which is a type error — a default of None requires Optional.
    def __init__(self, message: str, errors: Optional[List[str]] = None):
        super().__init__(message)
        self.errors = errors or []

    def __str__(self):
        if self.errors:
            # Show at most 10 detailed errors, then summarize the remainder.
            error_list = "\n - ".join(self.errors[:10])
            suffix = f"\n ... and {len(self.errors) - 10} more errors" if len(self.errors) > 10 else ""
            return f"{self.args[0]}\n - {error_list}{suffix}"
        return self.args[0]
# =============================================================================
# Data Format Detection & Loading
# =============================================================================
class DataFormat:
    """Enum-like class for supported data formats.

    Plain string constants (rather than ``enum.Enum``) so values can be
    compared and serialized directly by ``detect_data_format`` /
    ``DataLoader_Factory.load``.
    """
    JSONL = "jsonl"                              # single JSONL file path (str or Path)
    JSONL_LIST = "jsonl_list"                    # list of JSONL file paths
    INPUT_EXAMPLE_LIST = "input_example_list"    # list of InputExample objects
    TRAINING_DATASET = "training_dataset"        # TrainingDataset instance
    DICT_LIST = "dict_list"                      # list of raw record dicts (or empty list)
    EXTRACTOR_DATASET = "extractor_dataset"      # legacy internal dataset (duck-typed by class name)
def detect_data_format(data: Any) -> str:
    """
    Detect the format of input data.

    Parameters
    ----------
    data : Any
        Input data in any supported format.

    Returns
    -------
    str
        The detected format name from DataFormat.

    Raises
    ------
    ValueError
        If the data matches none of the supported formats.
    """
    # A bare string or Path is a single JSONL file.
    if isinstance(data, (str, Path)):
        return DataFormat.JSONL
    if isinstance(data, list):
        # Empty lists default to the raw-dict format.
        if not data:
            return DataFormat.DICT_LIST
        # Dispatch on the type of the first element.
        head = data[0]
        if isinstance(head, (str, Path)):
            return DataFormat.JSONL_LIST
        if isinstance(head, InputExample):
            return DataFormat.INPUT_EXAMPLE_LIST
        if isinstance(head, dict):
            return DataFormat.DICT_LIST
    elif isinstance(data, TrainingDataset):
        return DataFormat.TRAINING_DATASET
    elif type(data).__name__ == 'ExtractorDataset':
        # ExtractorDataset is matched by class name to avoid importing it here.
        return DataFormat.EXTRACTOR_DATASET
    raise ValueError(f"Unsupported data format: {type(data)}")
class DataLoader_Factory:
    """
    Factory for loading data from various formats into a unified internal format.
    All loaders convert data to List[Dict] format where each dict has:
    - "input": str (the text)
    - "output": Dict (the schema/annotations)
    Or alternatively:
    - "text": str
    - "schema": Dict
    """
    @staticmethod
    def load(
        data: Any,
        max_samples: int = -1,
        shuffle: bool = True,
        seed: int = 42,
        validate: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Load data from any supported format.
        Parameters
        ----------
        data : Any
            Input data in any supported format.
        max_samples : int, default=-1
            Maximum samples to load (-1 = all).
        shuffle : bool, default=True
            Whether to shuffle the data.
        seed : int, default=42
            Random seed for shuffling.
        validate : bool, default=False
            Whether to validate the data. Validation is always strict:
            checks that entity spans, relation values, and structure
            field values exist in the text.
        Returns
        -------
        List[Dict[str, Any]]
            List of records in unified format.
        """
        fmt = detect_data_format(data)
        # Load based on format
        if fmt == DataFormat.JSONL:
            records = DataLoader_Factory._load_jsonl(data)
        elif fmt == DataFormat.JSONL_LIST:
            records = DataLoader_Factory._load_jsonl_list(data)
        elif fmt == DataFormat.INPUT_EXAMPLE_LIST:
            records = DataLoader_Factory._load_input_examples(data)
        elif fmt == DataFormat.TRAINING_DATASET:
            records = DataLoader_Factory._load_training_dataset(data)
        elif fmt == DataFormat.DICT_LIST:
            records = DataLoader_Factory._load_dict_list(data)
        elif fmt == DataFormat.EXTRACTOR_DATASET:
            # Shallow copy so the shuffle/truncation below cannot reorder
            # the legacy dataset's own record list.
            records = data.data.copy()
        else:
            raise ValueError(f"Unsupported data format: {type(data)}")
        # Validate if requested.
        # NOTE: _validate_records sanitizes records IN PLACE (records[i] may
        # be replaced by a cleaned-up copy) and returns the indices that
        # survived; dropped records are reported to stdout.
        if validate and records:
            valid_indices, invalid_info = DataLoader_Factory._validate_records(records)
            if invalid_info:
                total_records = len(records)
                num_invalid = len(invalid_info)
                num_valid = len(valid_indices)
                print(f"\nValidation: Found {num_invalid} invalid record(s) out of {total_records} total")
                print("Removed invalid records:")
                # Print first 5 invalid records
                for idx, (record_idx, record, errors) in enumerate(invalid_info[:5]):
                    # Print first error for this record
                    error_msg = errors[0] if errors else "Unknown error"
                    print(f" Record {record_idx}: {error_msg}")
                if num_invalid > 5:
                    print(f" ... and {num_invalid - 5} more invalid record(s)")
                print(f"Kept {num_valid} valid record(s)\n")
            # Filter records to keep only valid ones (no-op when nothing
            # was dropped, since valid_indices then covers every record).
            records = [records[i] for i in valid_indices]
        # Shuffle BEFORE truncation so max_samples draws a random subset.
        # NOTE(review): this seeds the global random module — callers that
        # rely on their own random state should be aware.
        if shuffle and records:
            random.seed(seed)
            random.shuffle(records)
        # Limit samples
        if max_samples > 0 and len(records) > max_samples:
            records = records[:max_samples]
        return records
    @staticmethod
    def _load_jsonl(path: Union[str, Path]) -> List[Dict]:
        """Load from single JSONL file.

        Raises FileNotFoundError for a missing file and ValueError (with
        file name and line number) for malformed JSON lines. Blank lines
        are skipped.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")
        records = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if line:
                    try:
                        records.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        raise ValueError(f"Invalid JSON in {path} line {line_num}: {e}")
        return records
    @staticmethod
    def _load_jsonl_list(paths: List[Union[str, Path]]) -> List[Dict]:
        """Load from multiple JSONL files, concatenated in the given order."""
        records = []
        for path in paths:
            records.extend(DataLoader_Factory._load_jsonl(path))
        return records
    @staticmethod
    def _load_input_examples(examples: List[InputExample]) -> List[Dict]:
        """Load from list of InputExample objects."""
        return [ex.to_dict() for ex in examples]
    @staticmethod
    def _load_training_dataset(dataset: TrainingDataset) -> List[Dict]:
        """Load from TrainingDataset object."""
        return dataset.to_records()
    @staticmethod
    def _load_dict_list(dicts: List[Dict]) -> List[Dict]:
        """Load from list of dicts.

        The expected shape is sniffed from the FIRST dict only; mixed-shape
        lists are assumed homogeneous.
        """
        if not dicts:
            return []
        first = dicts[0]
        # Check format
        if "input" in first and "output" in first:
            # Already in correct format
            return dicts
        elif "text" in first and "schema" in first:
            # Alternative format - keep as is (handled in __getitem__)
            return dicts
        elif "text" in first:
            # Maybe has entities/classifications at top level - try to convert
            # by folding the known annotation keys into an "output" dict.
            records = []
            for d in dicts:
                output = {}
                if "entities" in d:
                    output["entities"] = d["entities"]
                if "classifications" in d:
                    output["classifications"] = d["classifications"]
                if "relations" in d:
                    output["relations"] = d["relations"]
                if "json_structures" in d:
                    output["json_structures"] = d["json_structures"]
                records.append({"input": d["text"], "output": output})
            return records
        else:
            raise ValueError(
                f"Unknown dict format. Expected keys like 'input'/'output', 'text'/'schema', "
                f"or 'text' with annotation keys. Got: {list(first.keys())}"
            )
    @staticmethod
    def _validate_records(records: List[Dict]) -> Tuple[List[int], List[Tuple[int, Dict, List[str]]]]:
        """
        Validate and sanitize all records, removing only invalid parts.
        Uses granular sanitization:
        - Entities: Drop entity type if any mention not found
        - Classifications: Drop individual invalid classifications
        - Structures: Remove invalid fields, drop if all invalid
        - Relations: Drop relation if any field invalid
        - Record: Drop only if no valid tasks remain
        Returns
        -------
        Tuple[List[int], List[Tuple[int, Dict, List[str]]]]
            - First element: List of valid record indices (sanitized records replace originals)
            - Second element: List of (index, original_record, warning_messages) for dropped records
        """
        valid_indices = []
        invalid_info = []
        for i, record in tqdm(enumerate(records), total=len(records), desc="Validating records", unit="record"):
            warnings = []
            try:
                # Round-trip through InputExample: parse, sanitize in place,
                # then serialize back so the caller's list holds the clean copy.
                example = InputExample.from_dict(record)
                sanitize_warnings, is_valid = example.sanitize()
                if is_valid:
                    # Replace record with sanitized version
                    records[i] = example.to_dict()
                    valid_indices.append(i)
                    if sanitize_warnings:
                        # Record was sanitized but still valid
                        warnings.extend(sanitize_warnings)
                else:
                    # No valid content remains
                    warnings.extend(sanitize_warnings)
                    invalid_info.append((i, record, warnings))
            except Exception as e:
                # Unparseable records (e.g. missing "input"/"output" keys)
                # are reported rather than raised, so one bad record cannot
                # abort a whole load.
                warnings.append(f"Failed to parse - {e}")
                invalid_info.append((i, record, warnings))
        return valid_indices, invalid_info
# Type alias for every input shape accepted by DataLoader_Factory.load /
# detect_data_format. Forward references are quoted because the classes are
# defined later in this module (ExtractorDataset lives outside it entirely).
TrainDataInput = Union[
    str,  # Single JSONL path
    Path,  # Single JSONL path
    List[str],  # Multiple JSONL paths
    List[Path],  # Multiple JSONL paths
    List[Dict[str, Any]],  # Raw records
    'TrainingDataset',  # TrainingDataset (forward reference)
    'List[InputExample]',  # List of InputExample (forward reference)
    'ExtractorDataset',  # Legacy dataset (forward reference)
]
# =============================================================================
# Training Data Classes
# =============================================================================
@dataclass
class Classification:
    """
    A classification task definition.

    Parameters
    ----------
    task : str
        Name of the classification task (e.g., "sentiment", "category").
    labels : List[str]
        All possible labels for this task.
    true_label : str or List[str]
        The correct label(s) for this example. Passing more than one label
        automatically switches the task to multi-label mode.
    multi_label : bool, default=False
        Whether multiple labels can be selected.
    prompt : str, optional
        Custom prompt for the task.
    examples : List[Tuple[str, str]], optional
        Few-shot examples as (input, output) pairs.
    label_descriptions : Dict[str, str], optional
        Descriptions for each label.
    """
    task: str
    labels: List[str]
    true_label: Union[str, List[str]]
    multi_label: bool = False
    prompt: Optional[str] = None
    examples: Optional[List[Tuple[str, str]]] = None
    label_descriptions: Optional[Dict[str, str]] = None

    def __post_init__(self):
        # Normalize true_label to an internal list form.
        gold = [self.true_label] if isinstance(self.true_label, str) else list(self.true_label)
        self._true_label_list = gold
        # More than one gold label implies a multi-label task.
        if len(gold) > 1:
            self.multi_label = True

    def validate(self) -> List[str]:
        """Validate this classification and return list of errors."""
        errors: List[str] = []
        if not self.task:
            errors.append("Classification task name cannot be empty")
        if not self.labels:
            errors.append(f"Classification '{self.task}' has no labels")
        errors.extend(
            f"True label '{gold}' not in labels list for task '{self.task}'"
            for gold in self._true_label_list
            if gold not in self.labels
        )
        # Defensive: unreachable in practice because __post_init__ flips
        # multi_label on whenever several gold labels are supplied.
        if len(self._true_label_list) > 1 and not self.multi_label:
            errors.append(f"Multiple true labels provided for '{self.task}' but multi_label=False")
        if self.label_descriptions:
            errors.extend(
                f"Label description key '{key}' not in labels for task '{self.task}'"
                for key in self.label_descriptions
                if key not in self.labels
            )
        if self.examples:
            for i, pair in enumerate(self.examples):
                if not isinstance(pair, (list, tuple)) or len(pair) != 2:
                    errors.append(f"Example {i} for task '{self.task}' must be (input, output) pair")
        return errors

    def to_dict(self) -> Dict[str, Any]:
        """Convert to training format dictionary."""
        payload: Dict[str, Any] = {
            "task": self.task,
            "labels": self.labels,
            "true_label": self._true_label_list,
        }
        # Optional keys are emitted only when set, keeping records compact.
        if self.multi_label:
            payload["multi_label"] = True
        if self.prompt:
            payload["prompt"] = self.prompt
        if self.examples:
            payload["examples"] = [list(pair) for pair in self.examples]
        if self.label_descriptions:
            payload["label_descriptions"] = self.label_descriptions
        return payload
@dataclass
class ChoiceField:
    """
    A field whose value must come from a fixed set of options
    (a classification-style field inside a structure).

    Parameters
    ----------
    value : str
        The selected value.
    choices : List[str]
        All permitted values.
    """
    value: str
    choices: List[str]

    def validate(self, field_name: str) -> List[str]:
        """Return a single-error list when value is not one of choices, else []."""
        if self.value in self.choices:
            return []
        return [f"Choice value '{self.value}' not in choices {self.choices} for field '{field_name}'"]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the training-format dict shape."""
        return {"value": self.value, "choices": self.choices}
@dataclass
class Structure:
    """
    A structured data extraction definition.

    Parameters
    ----------
    struct_name : str
        Name of the structure (e.g., "product", "contact").
    **fields : Any
        Field names and values. Values can be:
        - str: Single string value
        - List[str]: Multiple values
        - ChoiceField: Classification-style field with choices

    Examples
    --------
    >>> struct = Structure("product", name="iPhone", price="$999")
    >>> struct = Structure("contact", name="John", email="john@example.com")
    """
    struct_name: str
    _fields: Dict[str, Any] = field(default_factory=dict)
    descriptions: Optional[Dict[str, str]] = None

    # Hand-written __init__ so arbitrary field names can be passed as kwargs;
    # @dataclass never overwrites a user-defined __init__.
    def __init__(self, struct_name: str, _descriptions: Dict[str, str] = None, **fields):
        self.struct_name = struct_name
        self._fields = fields
        self.descriptions = _descriptions

    def validate(self, text: str) -> List[str]:
        """
        Validate this structure.

        Parameters
        ----------
        text : str
            The text to validate against. Field values must exist in this
            text (case-insensitive).

        Returns
        -------
        List[str]
            List of validation errors.
        """
        errors: List[str] = []
        if not self.struct_name:
            errors.append("Structure name cannot be empty")
        if not self._fields:
            errors.append(f"Structure '{self.struct_name}' has no fields")
        haystack = text.lower()
        for name, value in self._fields.items():
            if isinstance(value, ChoiceField):
                # Choice fields validate against their own choice list.
                errors.extend(value.validate(f"{self.struct_name}.{name}"))
            elif isinstance(value, list):
                for i, item in enumerate(value):
                    if item and item.lower() not in haystack:
                        errors.append(f"List value '{item}' at index {i} in '{self.struct_name}.{name}' not found in text")
            elif isinstance(value, str):
                if value and value.lower() not in haystack:
                    errors.append(f"Value '{value}' for '{self.struct_name}.{name}' not found in text")
        return errors

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        """Serialize as ``{struct_name: {field: value}}``; ChoiceFields become dicts."""
        serialized = {
            name: (value.to_dict() if isinstance(value, ChoiceField) else value)
            for name, value in self._fields.items()
        }
        return {self.struct_name: serialized}

    def get_descriptions(self) -> Optional[Dict[str, Dict[str, str]]]:
        """Return ``{struct_name: descriptions}`` or None when no descriptions were given."""
        return {self.struct_name: self.descriptions} if self.descriptions else None
@dataclass
class Relation:
    """
    A relation extraction definition.

    Parameters
    ----------
    name : str
        Name of the relation (e.g., "works_for", "founded").
    head : str, optional
        The source/subject entity.
    tail : str, optional
        The target/object entity.
    **fields : Any
        Custom field names and values; may be used instead of, or combined
        with, head/tail.
    """
    name: str
    head: Optional[str] = None
    tail: Optional[str] = None
    _fields: Dict[str, str] = field(default_factory=dict)

    # Hand-written __init__ so arbitrary field names can be passed as kwargs;
    # @dataclass never overwrites a user-defined __init__.
    def __init__(self, name: str, head: str = None, tail: str = None, **fields):
        self.name = name
        self.head = head
        self.tail = tail
        # Bug fix: the original dropped head/tail from _fields whenever any
        # extra custom field was also passed (e.g. Relation("r", head=...,
        # tail=..., year=...) kept only {"year": ...}), which silently lost
        # data on the InputExample.from_dict round-trip. head/tail are now
        # always recorded, with custom fields merged on top.
        self._fields = {}
        if head is not None:
            self._fields["head"] = head
        if tail is not None:
            self._fields["tail"] = tail
        self._fields.update(fields)

    def validate(self, text: str) -> List[str]:
        """
        Validate this relation.

        Parameters
        ----------
        text : str
            The text to validate against. Field values must exist in this
            text (case-insensitive).

        Returns
        -------
        List[str]
            List of validation errors.
        """
        errors = []
        if not self.name:
            errors.append("Relation name cannot be empty")
        if not self._fields:
            errors.append(f"Relation '{self.name}' has no fields")
        for field_name, value in self._fields.items():
            # Only non-empty string values are checked against the text.
            if isinstance(value, str) and value:
                if value.lower() not in text.lower():
                    errors.append(f"Relation value '{value}' for '{self.name}.{field_name}' not found in text")
        return errors

    def get_field_names(self) -> List[str]:
        """Return this relation's field names in insertion order."""
        return list(self._fields.keys())

    def to_dict(self) -> Dict[str, Dict[str, str]]:
        """Serialize as ``{relation_name: {field: value}}``."""
        return {self.name: self._fields}
@dataclass
class InputExample:
    """
    A single training example for GLiNER2.
    Parameters
    ----------
    text : str
        The input text for this example.
    entities : Dict[str, List[str]], optional
        Entity type to mentions mapping.
    entity_descriptions : Dict[str, str], optional
        Descriptions for entity types.
    classifications : List[Classification], optional
        Classification tasks for this example.
    structures : List[Structure], optional
        Structured data extractions for this example.
    relations : List[Relation], optional
        Relation extractions for this example.
    Examples
    --------
    >>> example = InputExample(
    ...     text="John Smith works at Google.",
    ...     entities={"person": ["John Smith"], "company": ["Google"]}
    ... )
    """
    text: str
    entities: Optional[Dict[str, List[str]]] = None
    entity_descriptions: Optional[Dict[str, str]] = None
    classifications: Optional[List[Classification]] = None
    structures: Optional[List[Structure]] = None
    relations: Optional[List[Relation]] = None

    def __post_init__(self):
        # Normalize the task containers to empty collections so the rest of
        # the class can iterate them without None checks.
        # (entity_descriptions is deliberately left possibly None.)
        if self.entities is None:
            self.entities = {}
        if self.classifications is None:
            self.classifications = []
        if self.structures is None:
            self.structures = []
        if self.relations is None:
            self.relations = []

    def validate(self) -> List[str]:
        """
        Validate this example.
        Validation is always strict: checks that entity mentions, relation values,
        and structure field values exist in the text (case-insensitive).
        Returns
        -------
        List[str]
            List of validation errors. Empty list means valid.
        """
        errors = []
        if not self.text or not self.text.strip():
            # Without text, no other check is meaningful - report and bail out.
            errors.append("Text cannot be empty")
            return errors
        if self.entities:
            for entity_type, mentions in self.entities.items():
                if not entity_type:
                    errors.append("Entity type cannot be empty")
                for mention in mentions:
                    if mention and mention.lower() not in self.text.lower():
                        errors.append(f"Entity '{mention}' (type: {entity_type}) not found in text")
        if self.entity_descriptions and self.entities:
            # A description for a type with no gold mentions is flagged.
            for desc_type in self.entity_descriptions:
                if desc_type not in self.entities:
                    errors.append(f"Entity description for '{desc_type}' but no entities of that type")
        for cls in self.classifications:
            errors.extend(cls.validate())
        for struct in self.structures:
            errors.extend(struct.validate(self.text))
        # Within one example, every occurrence of a relation name must use
        # the same (sorted) field set.
        relation_fields = {}
        for rel in self.relations:
            errors.extend(rel.validate(self.text))
            field_names = tuple(sorted(rel.get_field_names()))
            if rel.name in relation_fields:
                if relation_fields[rel.name] != field_names:
                    errors.append(f"Relation '{rel.name}' has inconsistent fields: {relation_fields[rel.name]} vs {field_names}")
            else:
                relation_fields[rel.name] = field_names
        has_content = bool(self.entities) or bool(self.classifications) or bool(self.structures) or bool(self.relations)
        if not has_content:
            errors.append("Example must have at least one task (entities, classifications, structures, or relations)")
        return errors

    def is_valid(self) -> bool:
        """Check if this example is valid."""
        return len(self.validate()) == 0

    def sanitize(self) -> Tuple[List[str], bool]:
        """
        Remove invalid parts from this example, keeping only valid content.
        Mutates self in-place.
        Granular removal strategy:
        - Entities: Drop entire entity type if ANY mention is not found in text
        - Classifications: Drop individual classifications that have errors
        - Structures: Remove invalid fields; drop structure only if ALL fields become invalid
        - Relations: Drop the specific relation if ANY field has an error
        - Example: Mark as invalid only if no valid tasks remain
        Returns
        -------
        Tuple[List[str], bool]
            - List of warning messages about what was removed
            - bool: True if example still has valid content, False if should be dropped
        """
        warnings = []
        if not self.text or not self.text.strip():
            warnings.append("Text is empty")
            return warnings, False
        # 1. Sanitize entities - drop entity type if any mention not found
        if self.entities:
            types_to_remove = []
            for entity_type, mentions in self.entities.items():
                if not entity_type:
                    types_to_remove.append(entity_type)
                    warnings.append(f"Entity type is empty")
                    continue
                # Check if any mention is not in text; one bad mention
                # invalidates the whole type (partial mention lists would
                # teach the model to miss spans).
                has_invalid = False
                for mention in mentions:
                    if mention and mention.lower() not in self.text.lower():
                        has_invalid = True
                        warnings.append(f"Entity '{mention}' (type: {entity_type}) not found in text - dropping entity type")
                        break
                if has_invalid:
                    types_to_remove.append(entity_type)
            # Remove invalid entity types (outside the iteration above, so
            # the dict is not mutated while being iterated)
            for entity_type in types_to_remove:
                del self.entities[entity_type]
            # Clean up entity descriptions for removed types
            if self.entity_descriptions:
                desc_to_remove = [desc_type for desc_type in self.entity_descriptions if desc_type not in self.entities]
                for desc_type in desc_to_remove:
                    del self.entity_descriptions[desc_type]
        # 2. Sanitize classifications - drop individual invalid ones
        if self.classifications:
            valid_classifications = []
            for cls in self.classifications:
                cls_errors = cls.validate()
                if cls_errors:
                    warnings.append(f"Classification '{cls.task}' has errors - dropping: {cls_errors[0]}")
                else:
                    valid_classifications.append(cls)
            self.classifications = valid_classifications
        # 3. Sanitize structures - remove invalid fields, drop if all invalid
        if self.structures:
            valid_structures = []
            for struct in self.structures:
                if not struct.struct_name:
                    warnings.append(f"Structure has empty name - dropping")
                    continue
                if not struct._fields:
                    warnings.append(f"Structure '{struct.struct_name}' has no fields - dropping")
                    continue
                # Filter out invalid fields (mirrors Structure.validate,
                # but applied per-field so partial structures survive)
                valid_fields = {}
                for field_name, value in struct._fields.items():
                    is_valid = True
                    if isinstance(value, ChoiceField):
                        field_errors = value.validate(f"{struct.struct_name}.{field_name}")
                        if field_errors:
                            warnings.append(f"Field '{struct.struct_name}.{field_name}' invalid - dropping field")
                            is_valid = False
                    elif isinstance(value, list):
                        for v in value:
                            if v and v.lower() not in self.text.lower():
                                warnings.append(f"List value '{v}' in '{struct.struct_name}.{field_name}' not found - dropping field")
                                is_valid = False
                                break
                    elif isinstance(value, str):
                        if value and value.lower() not in self.text.lower():
                            warnings.append(f"Value '{value}' for '{struct.struct_name}.{field_name}' not found - dropping field")
                            is_valid = False
                    if is_valid:
                        valid_fields[field_name] = value
                # Only keep structure if it has at least one valid field
                if valid_fields:
                    struct._fields = valid_fields
                    valid_structures.append(struct)
                else:
                    warnings.append(f"Structure '{struct.struct_name}' has no valid fields - dropping")
            self.structures = valid_structures
        # 4. Sanitize relations - drop entire relation if any field is invalid
        # (a relation with a missing endpoint is meaningless, so no partial keep)
        if self.relations:
            valid_relations = []
            for rel in self.relations:
                if not rel.name:
                    warnings.append(f"Relation has empty name - dropping")
                    continue
                if not rel._fields:
                    warnings.append(f"Relation '{rel.name}' has no fields - dropping")
                    continue
                # Check if any field value is invalid
                has_invalid = False
                for field_name, value in rel._fields.items():
                    if isinstance(value, str) and value:
                        if value.lower() not in self.text.lower():
                            warnings.append(f"Relation '{rel.name}' field '{field_name}' value '{value}' not found - dropping relation")
                            has_invalid = True
                            break
                if not has_invalid:
                    valid_relations.append(rel)
            self.relations = valid_relations
        # Check if example still has any valid content
        has_content = bool(self.entities) or bool(self.classifications) or bool(self.structures) or bool(self.relations)
        if not has_content:
            warnings.append("No valid tasks remain after sanitization")
            return warnings, False
        return warnings, True

    def to_dict(self) -> Dict[str, Any]:
        """Convert to GLiNER2 training format: {"input": text, "output": {...}}.

        Empty task containers are omitted from "output" entirely.
        """
        output = {}
        if self.entities:
            output["entities"] = self.entities
        if self.entity_descriptions:
            output["entity_descriptions"] = self.entity_descriptions
        if self.classifications:
            output["classifications"] = [cls.to_dict() for cls in self.classifications]
        if self.structures:
            output["json_structures"] = [struct.to_dict() for struct in self.structures]
            # Merge per-structure descriptions into one top-level mapping.
            all_descriptions = {}
            for struct in self.structures:
                desc = struct.get_descriptions()
                if desc:
                    all_descriptions.update(desc)
            if all_descriptions:
                output["json_descriptions"] = all_descriptions
        if self.relations:
            output["relations"] = [rel.to_dict() for rel in self.relations]
        return {"input": self.text, "output": output}

    def to_json(self) -> str:
        """Serialize to a single JSON line (non-ASCII kept as-is)."""
        return json.dumps(self.to_dict(), ensure_ascii=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'InputExample':
        """Create InputExample from training format dictionary.

        Expects the {"input": ..., "output": ...} shape produced by
        to_dict(); raises KeyError for other record shapes (e.g.
        "text"/"schema" records).
        """
        text = data["input"]
        output = data["output"]
        entities = output.get("entities")
        entity_descriptions = output.get("entity_descriptions")
        classifications = []
        for cls_data in output.get("classifications", []):
            classifications.append(Classification(
                task=cls_data["task"],
                labels=cls_data["labels"],
                true_label=cls_data["true_label"],
                multi_label=cls_data.get("multi_label", False),
                prompt=cls_data.get("prompt"),
                # `or None` collapses an empty examples list back to None.
                examples=[tuple(ex) for ex in cls_data.get("examples", [])] or None,
                label_descriptions=cls_data.get("label_descriptions")
            ))
        structures = []
        json_descriptions = output.get("json_descriptions", {})
        for struct_data in output.get("json_structures", []):
            # Each serialized structure is a one-key dict {name: fields}.
            for struct_name, fields in struct_data.items():
                parsed_fields = {}
                for field_name, value in fields.items():
                    # A dict with "value"+"choices" keys round-trips back
                    # into a ChoiceField; everything else passes through.
                    if isinstance(value, dict) and "value" in value and "choices" in value:
                        parsed_fields[field_name] = ChoiceField(value["value"], value["choices"])
                    else:
                        parsed_fields[field_name] = value
                structures.append(Structure(struct_name, _descriptions=json_descriptions.get(struct_name), **parsed_fields))
        relations = []
        for rel_data in output.get("relations", []):
            # Each serialized relation is a one-key dict {name: fields}.
            for rel_name, fields in rel_data.items():
                # Exactly head+tail maps onto the named parameters;
                # anything else is passed as custom **fields.
                if "head" in fields and "tail" in fields and len(fields) == 2:
                    relations.append(Relation(rel_name, head=fields["head"], tail=fields["tail"]))
                else:
                    relations.append(Relation(rel_name, **fields))
        return cls(
            text=text,
            entities=entities,
            entity_descriptions=entity_descriptions,
            classifications=classifications if classifications else None,
            structures=structures if structures else None,
            relations=relations if relations else None
        )

    @classmethod
    def from_json(cls, json_str: str) -> 'InputExample':
        """Parse a JSON line produced by to_json()."""
        return cls.from_dict(json.loads(json_str))
class TrainingDataset:
"""
A collection of InputExamples for training GLiNER2.
Can be created from:
- List of InputExample objects
- JSONL file path(s)
- Raw dict data
Parameters
----------
examples : List[InputExample], optional
Initial list of examples.
Examples
--------
>>> # From InputExample list
>>> dataset = TrainingDataset([example1, example2])
>>>
>>> # From JSONL file
>>> dataset = TrainingDataset.load("train.jsonl")
>>>
>>> # From multiple JSONL files
>>> dataset = TrainingDataset.load(["train1.jsonl", "train2.jsonl"])
"""
def __init__(self, examples: List[InputExample] = None):
self.examples: List[InputExample] = examples or []
def __len__(self) -> int:
return len(self.examples)
def __getitem__(self, idx: int) -> InputExample:
return self.examples[idx]
def __iter__(self) -> Iterator[InputExample]:
return iter(self.examples)
def add(self, example: InputExample) -> 'TrainingDataset':
self.examples.append(example)
return self
def add_many(self, examples: List[InputExample]) -> 'TrainingDataset':
self.examples.extend(examples)
return self
def validate(self, raise_on_error: bool = True) -> Dict[str, Any]:
"""
Validate all examples in the dataset.
Validation is always strict: checks that entity mentions, relation values,
and structure field values exist in the text (case-insensitive).
Parameters
----------
raise_on_error : bool, default=True
If True, raises ValidationError when invalid examples are found.
If False, returns validation report without raising.
Returns
-------
Dict[str, Any]
Validation report with counts and error details.
"""
all_errors = []
valid_count = 0
invalid_indices = []
for i, example in enumerate(self.examples):
errors = example.validate()
if errors:
invalid_indices.append(i)
for error in errors:
all_errors.append(f"Example {i}: {error}")
else:
valid_count += 1
report = {
"valid": valid_count,
"invalid": len(invalid_indices),
"total": len(self.examples),
"invalid_indices": invalid_indices,
"errors": all_errors
}
if all_errors and raise_on_error:
raise ValidationError(f"Dataset validation failed: {len(invalid_indices)} invalid examples", all_errors)
return report
def validate_relation_consistency(self) -> List[str]:
"""Validate that relation field structures are consistent across the dataset."""
errors = []
relation_fields: Dict[str, Tuple[int, Tuple[str, ...]]] = {}
for i, example in enumerate(self.examples):
for rel in example.relations:
field_names = tuple(sorted(rel.get_field_names()))
if rel.name in relation_fields:
first_idx, first_fields = relation_fields[rel.name]
if first_fields != field_names:
errors.append(f"Relation '{rel.name}' field inconsistency: Example {first_idx} has {list(first_fields)}, but Example {i} has {list(field_names)}")
else:
relation_fields[rel.name] = (i, field_names)
return errors
def stats(self) -> Dict[str, Any]:
"""Get statistics about the dataset."""
stats = {
"total_examples": len(self.examples),
"entity_types": Counter(),
"entity_mentions": 0,
"classification_tasks": Counter(),
"classification_labels": {},
"structure_types": Counter(),
"relation_types": Counter(),
"text_lengths": [],
"task_distribution": {
"entities_only": 0, "classifications_only": 0, "structures_only": 0,
"relations_only": 0, "multi_task": 0, "empty": 0
}
}
for example in self.examples:
stats["text_lengths"].append(len(example.text))
for entity_type, mentions in example.entities.items():
stats["entity_types"][entity_type] += len(mentions)
stats["entity_mentions"] += len(mentions)
for cls in example.classifications:
stats["classification_tasks"][cls.task] += 1
if cls.task not in stats["classification_labels"]:
stats["classification_labels"][cls.task] = Counter()
for label in cls._true_label_list:
stats["classification_labels"][cls.task][label] += 1
for struct in example.structures:
stats["structure_types"][struct.struct_name] += 1
for rel in example.relations:
stats["relation_types"][rel.name] += 1
has_entities = bool(example.entities)
has_cls = bool(example.classifications)
has_struct = bool(example.structures)
has_rel = bool(example.relations)
task_count = sum([has_entities, has_cls, has_struct, has_rel])
if task_count == 0:
stats["task_distribution"]["empty"] += 1
elif task_count > 1:
stats["task_distribution"]["multi_task"] += 1
elif has_entities:
stats["task_distribution"]["entities_only"] += 1
elif has_cls:
stats["task_distribution"]["classifications_only"] += 1
elif has_struct:
stats["task_distribution"]["structures_only"] += 1
elif has_rel:
stats["task_distribution"]["relations_only"] += 1
if stats["text_lengths"]:
lengths = stats["text_lengths"]
stats["text_length_stats"] = {
"min": min(lengths), "max": max(lengths),
"mean": sum(lengths) / len(lengths),
"median": sorted(lengths)[len(lengths) // 2]
}
stats["entity_types"] = dict(stats["entity_types"])
stats["classification_tasks"] = dict(stats["classification_tasks"])
stats["classification_labels"] = {k: dict(v) for k, v in stats["classification_labels"].items()}
stats["structure_types"] = dict(stats["structure_types"])
stats["relation_types"] = dict(stats["relation_types"])
return stats
def print_stats(self):
    """Pretty-print the dataset statistics produced by ``self.stats()``.

    Output sections (entity types, classification tasks, structures,
    relations) are only shown when non-empty.
    """
    summary = self.stats()
    total = summary['total_examples']
    print(f"\n{'='*60}")
    print(f"GLiNER2 Training Dataset Statistics")
    print(f"{'='*60}")
    print(f"Total examples: {total}")
    length_info = summary.get('text_length_stats')
    if length_info:
        print(f"\nText lengths: min={length_info['min']}, max={length_info['max']}, mean={length_info['mean']:.1f}")
    print(f"\nTask Distribution:")
    for task, count in summary['task_distribution'].items():
        if count > 0:
            print(f"  {task}: {count} ({100*count/total:.1f}%)")
    if summary['entity_types']:
        print(f"\nEntity Types ({summary['entity_mentions']} total mentions):")
        # reverse=True is stable, so this matches sorting by the negated count
        by_frequency = sorted(summary['entity_types'].items(), key=lambda item: item[1], reverse=True)
        for etype, count in by_frequency:
            print(f"  {etype}: {count}")
    if summary['classification_tasks']:
        print(f"\nClassification Tasks:")
        for task, count in summary['classification_tasks'].items():
            print(f"  {task}: {count} examples")
            for label, lcount in summary['classification_labels'].get(task, {}).items():
                print(f"    - {label}: {lcount}")
    if summary['structure_types']:
        print(f"\nStructure Types:")
        for stype, count in summary['structure_types'].items():
            print(f"  {stype}: {count}")
    if summary['relation_types']:
        print(f"\nRelation Types:")
        for rtype, count in summary['relation_types'].items():
            print(f"  {rtype}: {count}")
    print(f"{'='*60}\n")
def to_jsonl(self) -> str:
    """Serialize the dataset to a JSONL string (one JSON object per line)."""
    lines = [example.to_json() for example in self.examples]
    return "\n".join(lines)
def to_records(self) -> List[Dict[str, Any]]:
    """Return the dataset as plain record dicts suitable for the trainer."""
    return [record.to_dict() for record in self.examples]
def save(self, path: Union[str, Path], validate_first: bool = True):
    """Write the dataset to a JSONL file, creating parent directories.

    Parameters
    ----------
    path : str or Path
        Destination file path.
    validate_first : bool, default=True
        Run ``self.validate()`` before writing (raises on invalid data).
    """
    if validate_first:
        self.validate()
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, 'w', encoding='utf-8') as f:
        f.writelines(example.to_json() + '\n' for example in self.examples)
    print(f"Saved {len(self.examples)} examples to {target}")
@classmethod
def load(cls, paths: Union[str, Path, List[Union[str, Path]]], shuffle: bool = False, seed: int = 42) -> 'TrainingDataset':
    """
    Load dataset from JSONL file(s).

    Parameters
    ----------
    paths : str, Path, or List
        Single file path or list of file paths.
    shuffle : bool, default=False
        Whether to shuffle the loaded examples.
    seed : int, default=42
        Random seed for shuffling.

    Returns
    -------
    TrainingDataset

    Raises
    ------
    ValueError
        If a line is not valid JSON or cannot be parsed into an
        InputExample; the original exception is chained as __cause__.
    """
    if isinstance(paths, (str, Path)):
        paths = [paths]
    examples = []
    for path in paths:
        path = Path(path)
        # Remember where this file starts so the per-file count below is
        # accurate even when loading multiple files (previously it printed
        # the cumulative total for every file after the first).
        loaded_before = len(examples)
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue  # skip blank lines
                try:
                    data = json.loads(line)
                    examples.append(InputExample.from_dict(data))
                except json.JSONDecodeError as e:
                    # Chain the original error so debuggers see the full cause.
                    raise ValueError(f"Invalid JSON in {path} line {line_num}: {e}") from e
                except Exception as e:
                    raise ValueError(f"Error parsing {path} line {line_num}: {e}") from e
        print(f"Loaded {len(examples) - loaded_before} examples from {path}")
    if shuffle:
        random.seed(seed)
        random.shuffle(examples)
    return cls(examples)
@classmethod
def from_records(cls, records: List[Dict[str, Any]]) -> 'TrainingDataset':
    """Build a dataset from a list of record dicts (inverse of to_records)."""
    parsed = [InputExample.from_dict(record) for record in records]
    return cls(parsed)
def split(self, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1,
          shuffle: bool = True, seed: int = 42) -> Tuple['TrainingDataset', 'TrainingDataset', 'TrainingDataset']:
    """Partition the dataset into (train, val, test) subsets.

    Ratios must sum to 1.0 (within floating-point tolerance); boundaries
    are computed by truncation, so any rounding remainder lands in test.
    """
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError("Ratios must sum to 1.0")
    order = list(range(len(self.examples)))
    if shuffle:
        random.seed(seed)
        random.shuffle(order)
    total = len(order)
    n_train = int(total * train_ratio)
    n_val = int(total * val_ratio)

    def _subset(indices):
        # Materialize one split from the (possibly shuffled) index order.
        return TrainingDataset([self.examples[i] for i in indices])

    return (
        _subset(order[:n_train]),
        _subset(order[n_train:n_train + n_val]),
        _subset(order[n_train + n_val:]),
    )
def filter(self, predicate) -> 'TrainingDataset':
    """Return a new dataset containing only examples where predicate(example) is truthy."""
    kept = [example for example in self.examples if predicate(example)]
    return TrainingDataset(kept)
def sample(self, n: int, seed: int = 42) -> 'TrainingDataset':
    """Return a new dataset with up to ``n`` examples sampled without replacement."""
    random.seed(seed)
    k = min(n, len(self.examples))
    return TrainingDataset(random.sample(self.examples, k))
# Convenience functions
def create_entity_example(text: str, entities: Dict[str, List[str]],
                          descriptions: Optional[Dict[str, str]] = None) -> InputExample:
    """Create an entity extraction example.

    Parameters
    ----------
    text : str
        The input text.
    entities : dict
        Mapping of entity type -> list of mention strings.
    descriptions : dict, optional
        Mapping of entity type -> natural-language description.
        (Annotation fixed to ``Optional[...]``: PEP 484 disallows implicit
        Optional for ``= None`` defaults.)
    """
    return InputExample(text=text, entities=entities, entity_descriptions=descriptions)
def create_classification_example(text: str, task: str, labels: List[str], true_label: Union[str, List[str]],
                                  multi_label: bool = False, **kwargs) -> InputExample:
    """Create a single-task classification example.

    Extra keyword arguments are forwarded to ``Classification``.
    """
    classification = Classification(task=task, labels=labels, true_label=true_label,
                                    multi_label=multi_label, **kwargs)
    return InputExample(text=text, classifications=[classification])
def create_structure_example(text: str, structure_name: str, **fields) -> InputExample:
    """Create a structured-data example; ``fields`` become the structure's fields."""
    structure = Structure(structure_name, **fields)
    return InputExample(text=text, structures=[structure])
def create_relation_example(text: str, relation_name: str, head: Optional[str] = None,
                            tail: Optional[str] = None, **fields) -> InputExample:
    """Create a relation extraction example.

    Parameters
    ----------
    text : str
        The input text.
    relation_name : str
        Name of the relation type.
    head, tail : str, optional
        Head/tail mention strings. (Annotations fixed to ``Optional[str]``:
        PEP 484 disallows implicit Optional for ``= None`` defaults.)
    **fields
        Extra keyword arguments forwarded to ``Relation``.
    """
    return InputExample(text=text, relations=[Relation(relation_name, head=head, tail=tail, **fields)])