"""
GLiNER2 Schema Transformer with Optimized Batch Processing

This module handles all preprocessing for GLiNER2, with efficient batching
via DataLoader collate functions for parallel preprocessing.
"""
|
|
|
|
import copy
|
|
import random
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, Tuple, Iterator, List
|
|
import torch
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
# =============================================================================
|
|
# Data Structures
|
|
# =============================================================================
|
|
|
|
@dataclass
class TransformedRecord:
    """Single transformed record ready for batching.

    Produced by SchemaTransformer._transform_record and consumed by
    SchemaTransformer._pad_batch.
    """
    input_ids: List[int]                        # subword token IDs (schemas + separators + text)
    mapped_indices: List[Tuple[str, int, int]]  # per-subword (segment_type, original_token_idx, schema_idx)
    schema_tokens_list: List[List[str]]         # token sequence of each schema prompt
    text_tokens: List[str]                      # word-level text tokens (incl. classification prefix)
    structure_labels: List[Any]                 # ground-truth label structure, one entry per schema
    task_types: List[str]                       # task type per schema ("entities", "relations", ...)
    start_token_idx: List[int]                  # char start offset of each text token
    end_token_idx: List[int]                    # char end offset of each text token
    text: str                                   # original input text
    schema: Dict[str, Any]                      # original schema (choice info preserved)
    num_schemas: int = field(init=False)        # derived in __post_init__

    def __post_init__(self):
        # Derived field: number of schema prompts attached to this record.
        self.num_schemas = len(self.schema_tokens_list)
|
|
|
|
|
|
@dataclass
class PreprocessedBatch:
    """GPU-ready batch for training/inference.

    Implements a minimal mapping-like protocol (__contains__, __iter__,
    __getitem__, __len__) for HuggingFace Trainer compatibility.
    """
    input_ids: torch.Tensor                    # (batch, max_seq_len)
    attention_mask: torch.Tensor               # (batch, max_seq_len)
    mapped_indices: List[List[Tuple]]          # Per-sample token mappings
    schema_counts: List[int]                   # Number of schemas per sample
    original_lengths: List[int]                # Original sequence lengths
    structure_labels: List[List[Any]]          # Ground truth labels
    task_types: List[List[str]]                # Task types per schema
    text_tokens: List[List[str]]               # Original text tokens
    schema_tokens_list: List[List[List[str]]]  # Schema tokens per sample
    start_mappings: List[List[int]]            # Char position start mappings
    end_mappings: List[List[int]]              # Char position end mappings
    original_texts: List[str]                  # For result formatting
    original_schemas: List[Dict]               # For result formatting

    def _with_tensors(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> 'PreprocessedBatch':
        """Return a copy with new tensor fields; metadata is shared (not copied)."""
        return PreprocessedBatch(
            input_ids=input_ids,
            attention_mask=attention_mask,
            mapped_indices=self.mapped_indices,
            schema_counts=self.schema_counts,
            original_lengths=self.original_lengths,
            structure_labels=self.structure_labels,
            task_types=self.task_types,
            text_tokens=self.text_tokens,
            schema_tokens_list=self.schema_tokens_list,
            start_mappings=self.start_mappings,
            end_mappings=self.end_mappings,
            original_texts=self.original_texts,
            original_schemas=self.original_schemas,
        )

    def to(self, device: torch.device) -> 'PreprocessedBatch':
        """Move tensors to device."""
        return self._with_tensors(
            self.input_ids.to(device),
            self.attention_mask.to(device),
        )

    def pin_memory(self) -> 'PreprocessedBatch':
        """Pin tensors to page-locked memory for faster GPU transfer."""
        return self._with_tensors(
            self.input_ids.pin_memory(),
            self.attention_mask.pin_memory(),
        )

    def __contains__(self, key: str) -> bool:
        """Check if key is a field name. Required for HuggingFace Trainer compatibility.

        BUGFIX: previously used hasattr(), which also matched methods such
        as 'to' and 'pin_memory'; now only true dataclass fields qualify.
        """
        return key in self.__dataclass_fields__

    def __iter__(self):
        """Iterate over field names. Required for HuggingFace Trainer compatibility."""
        return iter(self.__dataclass_fields__.keys())

    def __getitem__(self, key):
        """Get field by name. Required for HuggingFace Trainer compatibility."""
        if isinstance(key, str):
            return getattr(self, key)
        raise KeyError(f"PreprocessedBatch does not support integer indexing: {key}")

    def __len__(self) -> int:
        """Number of samples in the batch."""
        return self.input_ids.shape[0]
|
|
|
|
|
|
# =============================================================================
|
|
# Tokenizer
|
|
# =============================================================================
|
|
|
|
class WhitespaceTokenSplitter:
    """Fast regex-based word tokenizer.

    Matches, in priority order: URLs, e-mail addresses, @-mentions,
    hyphen/underscore-joined words, and any remaining non-space character.
    Yields (token, char_start, char_end) triples.
    """
    __slots__ = ()

    _PATTERN = re.compile(
        r"""(?:https?://[^\s]+|www\.[^\s]+)
|[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,}
|@[a-z0-9_]+
|\w+(?:[-_]\w+)*
|\S""",
        re.VERBOSE | re.IGNORECASE,
    )

    def __call__(self, text: str, lower: bool = True) -> Iterator[Tuple[str, int, int]]:
        """Yield (token, start, end) for each match in `text`.

        NOTE(review): with lower=True the offsets index into the lowered
        string; for a few Unicode chars lowering changes length — assumed
        irrelevant for the targeted inputs, verify if offsets misalign.
        """
        haystack = text.lower() if lower else text
        for match in self._PATTERN.finditer(haystack):
            yield match.group(0), match.start(), match.end()
|
|
|
|
|
|
# =============================================================================
|
|
# Sampling Configuration
|
|
# =============================================================================
|
|
|
|
@dataclass
class SamplingConfig:
    """Configuration for stochastic sampling during training.

    All *_prob values are probabilities in [0, 1]. Sampling is only applied
    when the transformer is in training mode (otherwise sampling is None).
    """
    # JSON Structures
    remove_json_structure_prob: float = 0.2  # drop a whole JSON structure
    shuffle_json_fields: bool = True  # randomize field order
    remove_json_field_prob: float = 0.2  # drop individual fields
    # Entities
    remove_entities_prob: float = 0.0  # drop the entire entity task
    shuffle_entities: bool = False  # randomize entity label order
    remove_entity_prob: float = 0.0  # drop individual entity labels
    synthetic_entity_label_prob: float = 0.2  # replace labels with "entity N"/"field N" placeholders
    # Relations
    remove_relations_prob: float = 0.2  # drop individual relation items
    swap_head_tail_prob: float = 0.2  # swap head/tail field order
    # Classifications
    remove_classification_prob: float = 0.0  # drop a classification task
    shuffle_classification_labels: bool = True  # randomize label order
    remove_classification_label_prob: float = 0.5  # max fraction of labels to drop
    synthetic_label_prob: float = 0.5  # replace labels with "label N" placeholders
    include_true_label_prob: float = 0.5  # force gold label(s) back into the candidate set
    max_num_labels: int = 1000  # cap on number of classification labels
|
|
|
|
|
|
# =============================================================================
|
|
# Main Processor Class
|
|
# =============================================================================
|
|
|
|
class SchemaTransformer:
    """
    Schema-based text transformer for GLiNER2.

    Provides efficient batch preprocessing via collate functions
    for parallel DataLoader preprocessing.
    """

    # Special tokens (registered with the tokenizer in __init__)
    SEP_STRUCT = "[SEP_STRUCT]"     # separates consecutive schema prompts
    SEP_TEXT = "[SEP_TEXT]"         # separates schema section from the text
    P_TOKEN = "[P]"                 # parent/prompt marker
    C_TOKEN = "[C]"                 # JSON-structure field marker
    E_TOKEN = "[E]"                 # entity label marker
    R_TOKEN = "[R]"                 # relation field marker
    L_TOKEN = "[L]"                 # classification label marker
    EXAMPLE_TOKEN = "[EXAMPLE]"     # few-shot example input marker
    OUTPUT_TOKEN = "[OUTPUT]"       # few-shot example output marker
    DESC_TOKEN = "[DESCRIPTION]"    # label-description marker

    SPECIAL_TOKENS = [
        SEP_STRUCT, SEP_TEXT, P_TOKEN, C_TOKEN, E_TOKEN,
        R_TOKEN, L_TOKEN, EXAMPLE_TOKEN, OUTPUT_TOKEN, DESC_TOKEN
    ]

    def __init__(
        self,
        model_name: str = None,
        tokenizer=None,
        sampling_config: SamplingConfig = None,
        token_pooling: str = "first"
    ):
        """Create a transformer around a HuggingFace tokenizer.

        Args:
            model_name: Model name to load a tokenizer for (ignored when
                `tokenizer` is given).
            tokenizer: Pre-built tokenizer. NOTE: it is mutated — the special
                tokens above are added to its vocabulary.
            sampling_config: Training-time augmentation config; defaults to
                SamplingConfig().
            token_pooling: Subword pooling strategy ("first" | "mean" | "max");
                unknown values silently fall back to "first".

        Raises:
            ValueError: If neither model_name nor tokenizer is provided.
        """
        if model_name is None and tokenizer is None:
            raise ValueError("Either model_name or tokenizer must be provided.")

        # Unknown pooling strategies silently fall back to "first".
        self.token_pooling = token_pooling if token_pooling in ["first", "mean", "max"] else "first"
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_name)
        self.word_splitter = WhitespaceTokenSplitter()
        self.sampling_config = sampling_config or SamplingConfig()
        self.is_training = False

        # Add special tokens
        self.tokenizer.add_special_tokens({
            "additional_special_tokens": self.SPECIAL_TOKENS
        })

        # OPT-1: Pre-compute special token IDs for fast lookup in embedding extraction
        self._special_ids = frozenset(
            self.tokenizer.convert_tokens_to_ids(t)
            for t in (self.P_TOKEN, self.C_TOKEN, self.E_TOKEN, self.R_TOKEN, self.L_TOKEN)
        )

        # OPT-6: Cache tokenized forms of special tokens and common punctuation
        self._token_cache = {}
        for tok in self.SPECIAL_TOKENS + ["(", ")", ",", "|"]:
            self._token_cache[tok] = self.tokenizer.tokenize(tok)
|
|
|
|
def change_mode(self, is_training: bool):
|
|
"""Switch between training and inference mode."""
|
|
self.is_training = is_training
|
|
|
|
# =========================================================================
|
|
# Main Public API: Collate Functions
|
|
# =========================================================================
|
|
|
|
def collate_fn_train(
|
|
self,
|
|
batch: List[Tuple[str, Dict]]
|
|
) -> PreprocessedBatch:
|
|
"""
|
|
Collate function for training DataLoader.
|
|
|
|
Use this with DataLoader for parallel preprocessing:
|
|
|
|
loader = DataLoader(
|
|
dataset,
|
|
batch_size=32,
|
|
collate_fn=processor.collate_fn_train,
|
|
num_workers=8
|
|
)
|
|
|
|
Args:
|
|
batch: List of (text, schema) tuples from dataset
|
|
|
|
Returns:
|
|
PreprocessedBatch ready for model.forward()
|
|
"""
|
|
self.is_training = True
|
|
return self._collate_batch(batch)
|
|
|
|
def collate_fn_inference(
|
|
self,
|
|
batch: List[Tuple[str, Any]]
|
|
) -> PreprocessedBatch:
|
|
"""
|
|
Collate function for inference DataLoader.
|
|
|
|
Args:
|
|
batch: List of (text, schema) tuples
|
|
|
|
Returns:
|
|
PreprocessedBatch for batch_extract
|
|
"""
|
|
self.is_training = False
|
|
return self._collate_batch(batch)
|
|
|
|
def transform_and_format(
|
|
self,
|
|
text: str,
|
|
schema: Dict[str, Any]
|
|
) -> TransformedRecord:
|
|
"""
|
|
Transform and format a single record.
|
|
|
|
This is the main preprocessing entry point for single records.
|
|
For batch processing, use collate_fn_train/collate_fn_inference.
|
|
|
|
Args:
|
|
text: Input text
|
|
schema: Schema dictionary
|
|
|
|
Returns:
|
|
TransformedRecord ready for batching
|
|
"""
|
|
record = {"text": text, "schema": copy.deepcopy(schema)}
|
|
return self._transform_record(record)
|
|
|
|
# =========================================================================
|
|
# Internal: Batch Processing
|
|
# =========================================================================
|
|
|
|
def _collate_batch(
|
|
self,
|
|
batch: List[Tuple[str, Any]]
|
|
) -> PreprocessedBatch:
|
|
"""Internal collate implementation."""
|
|
transformed_records = []
|
|
|
|
for text, schema in batch:
|
|
# Handle Schema objects
|
|
if hasattr(schema, 'build'):
|
|
schema = schema.build()
|
|
elif hasattr(schema, 'schema'):
|
|
schema = schema.schema
|
|
|
|
# Ensure text ends with punctuation
|
|
if text and not text.endswith(('.', '!', '?')):
|
|
text = text + "."
|
|
elif not text:
|
|
text = "."
|
|
|
|
record = {"text": text, "schema": copy.deepcopy(schema)}
|
|
|
|
try:
|
|
transformed = self._transform_record(record)
|
|
transformed_records.append(transformed)
|
|
except Exception as e:
|
|
# Create minimal fallback record
|
|
transformed_records.append(self._create_fallback_record(text, schema))
|
|
|
|
return self._pad_batch(transformed_records)
|
|
|
|
def _transform_record(self, record: Dict[str, Any]) -> TransformedRecord:
|
|
"""Transform a single record (internal)."""
|
|
# OPT-4: Caller (_collate_batch) already deepcopies the schema.
|
|
# Only deepcopy here for direct callers (transform_and_format).
|
|
text, schema = record["text"], record["schema"]
|
|
|
|
# Build classification prefix
|
|
prefix = self._build_classification_prefix(schema)
|
|
|
|
# Save a copy of the original schema BEFORE wrapping modifies it
|
|
# This preserves choice field info for extraction
|
|
original_schema = copy.deepcopy(schema)
|
|
|
|
# Handle classification field wrapping
|
|
if prefix:
|
|
self._wrap_classification_fields(schema, prefix)
|
|
|
|
# Tokenize text
|
|
text_tokens = []
|
|
start_idx_map = []
|
|
end_idx_map = []
|
|
for tkn, start, end in self.word_splitter(text, lower=True):
|
|
text_tokens.append(tkn)
|
|
start_idx_map.append(start)
|
|
end_idx_map.append(end)
|
|
|
|
if prefix:
|
|
text_tokens = prefix + text_tokens
|
|
len_prefix = len(prefix)
|
|
|
|
# Infer schema
|
|
processed = self._infer_from_json(schema)
|
|
|
|
# Build outputs
|
|
results = self._build_outputs(
|
|
processed, schema, text_tokens, len_prefix
|
|
)
|
|
|
|
# Format input
|
|
schema_tokens_list = [r["schema_tokens"] for r in results]
|
|
format_result = self._format_input_with_mapping(schema_tokens_list, text_tokens)
|
|
|
|
return TransformedRecord(
|
|
input_ids=format_result["input_ids"],
|
|
mapped_indices=format_result["mapped_indices"],
|
|
schema_tokens_list=schema_tokens_list,
|
|
text_tokens=text_tokens,
|
|
structure_labels=[r["output"] for r in results],
|
|
task_types=[r["task_type"] for r in results],
|
|
start_token_idx=start_idx_map,
|
|
end_token_idx=end_idx_map,
|
|
text=text,
|
|
schema=original_schema, # Use original schema with choice info preserved
|
|
)
|
|
|
|
def _pad_batch(
|
|
self,
|
|
records: List[TransformedRecord]
|
|
) -> PreprocessedBatch:
|
|
"""Pad transformed records into a batch."""
|
|
if not records:
|
|
return self._empty_batch()
|
|
|
|
max_len = max(len(r.input_ids) for r in records)
|
|
batch_size = len(records)
|
|
|
|
# Pre-allocate tensors
|
|
input_ids = torch.zeros((batch_size, max_len), dtype=torch.long)
|
|
attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long)
|
|
original_lengths = []
|
|
|
|
for i, rec in enumerate(records):
|
|
seq_len = len(rec.input_ids)
|
|
input_ids[i, :seq_len] = torch.tensor(rec.input_ids, dtype=torch.long)
|
|
attention_mask[i, :seq_len] = 1
|
|
original_lengths.append(seq_len)
|
|
|
|
return PreprocessedBatch(
|
|
input_ids=input_ids,
|
|
attention_mask=attention_mask,
|
|
mapped_indices=[r.mapped_indices for r in records],
|
|
schema_counts=[r.num_schemas for r in records],
|
|
original_lengths=original_lengths,
|
|
structure_labels=[r.structure_labels for r in records],
|
|
task_types=[r.task_types for r in records],
|
|
text_tokens=[r.text_tokens for r in records],
|
|
schema_tokens_list=[r.schema_tokens_list for r in records],
|
|
start_mappings=[r.start_token_idx for r in records],
|
|
end_mappings=[r.end_token_idx for r in records],
|
|
original_texts=[r.text for r in records],
|
|
original_schemas=[r.schema for r in records],
|
|
)
|
|
|
|
def _empty_batch(self) -> PreprocessedBatch:
|
|
"""Create empty batch for edge cases."""
|
|
return PreprocessedBatch(
|
|
input_ids=torch.zeros((0, 0), dtype=torch.long),
|
|
attention_mask=torch.zeros((0, 0), dtype=torch.long),
|
|
mapped_indices=[],
|
|
schema_counts=[],
|
|
original_lengths=[],
|
|
structure_labels=[],
|
|
task_types=[],
|
|
text_tokens=[],
|
|
schema_tokens_list=[],
|
|
start_mappings=[],
|
|
end_mappings=[],
|
|
original_texts=[],
|
|
original_schemas=[],
|
|
)
|
|
|
|
def _create_fallback_record(self, text: str, schema: Dict) -> TransformedRecord:
|
|
"""Create minimal valid record for failed transformations."""
|
|
dummy_tokens = [
|
|
"(", "[P]", "dummy", "(", "[E]", "entity", ")", ")"
|
|
]
|
|
format_result = self._format_input_with_mapping([dummy_tokens], ["."])
|
|
|
|
return TransformedRecord(
|
|
input_ids=format_result["input_ids"],
|
|
mapped_indices=format_result["mapped_indices"],
|
|
schema_tokens_list=[dummy_tokens],
|
|
text_tokens=["."],
|
|
structure_labels=[[1, [[(0, 0)]]]],
|
|
task_types=["entities"],
|
|
start_token_idx=[0],
|
|
end_token_idx=[1],
|
|
text=text or ".",
|
|
schema=schema or {},
|
|
)
|
|
|
|
# =========================================================================
|
|
# Internal: Schema Processing
|
|
# =========================================================================
|
|
|
|
def _build_classification_prefix(self, schema: Dict[str, Any]) -> List[str]:
|
|
"""Build classification prefix tokens."""
|
|
prefix_tokens = []
|
|
|
|
for struct in schema.get("json_structures", []):
|
|
for parent, fields in struct.items():
|
|
cls_fields = [
|
|
(fname, fval) for fname, fval in fields.items()
|
|
if isinstance(fval, dict) and "value" in fval and "choices" in fval
|
|
]
|
|
|
|
if self.is_training:
|
|
random.shuffle(cls_fields)
|
|
|
|
inner = []
|
|
for fname, fval in cls_fields:
|
|
choices = fval["choices"].copy()
|
|
if self.is_training:
|
|
random.shuffle(choices)
|
|
|
|
choice_tokens = []
|
|
for i, c in enumerate(choices):
|
|
if i > 0:
|
|
choice_tokens.append('|')
|
|
choice_tokens.append(c)
|
|
|
|
inner.extend([fname, '('] + choice_tokens + [')', ','])
|
|
|
|
if inner:
|
|
inner = inner[:-1]
|
|
prefix_tokens.extend(['(', f"{parent}:", *inner, ')'])
|
|
|
|
return prefix_tokens
|
|
|
|
def _wrap_classification_fields(self, schema: Dict, prefix: List[str]):
|
|
"""Wrap classification field values with [selection] prefix."""
|
|
|
|
def wrap(val):
|
|
if isinstance(val, list):
|
|
return [f"[selection]{v}" for v in val]
|
|
return f"[selection]{val}"
|
|
|
|
cls_keys = {
|
|
f"{parent}.{fname}"
|
|
for struct in schema.get("json_structures", [])
|
|
for parent, fields in struct.items()
|
|
for fname, fval in fields.items()
|
|
if isinstance(fval, dict) and {"value", "choices"} <= fval.keys()
|
|
}
|
|
|
|
for struct in schema.get("json_structures", []):
|
|
for parent, fields in struct.items():
|
|
for fname in list(fields):
|
|
key = f"{parent}.{fname}"
|
|
if key not in cls_keys:
|
|
continue
|
|
fval = fields[fname]
|
|
raw = fval["value"] if isinstance(fval, dict) else fval
|
|
fields[fname] = wrap(raw)
|
|
|
|
def _infer_from_json(self, schema: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Infer schemas and labels from JSON schema."""
|
|
schemas = []
|
|
labels = []
|
|
types = []
|
|
|
|
sampling = self.sampling_config if self.is_training else None
|
|
|
|
# Process JSON structures
|
|
self._process_json_structures(schema, schemas, labels, types, sampling)
|
|
|
|
# Process entities
|
|
self._process_entities(schema, schemas, labels, types, sampling)
|
|
|
|
# Process relations
|
|
self._process_relations(schema, schemas, labels, types, sampling)
|
|
|
|
# Process classifications
|
|
self._process_classifications(schema, schemas, labels, types, sampling)
|
|
|
|
# Shuffle task order during training
|
|
if sampling:
|
|
order = list(range(len(types)))
|
|
random.shuffle(order)
|
|
schemas = [schemas[i] for i in order]
|
|
labels = [labels[i] for i in order]
|
|
types = [types[i] for i in order]
|
|
|
|
return {
|
|
"schemas": schemas,
|
|
"structure_labels": labels,
|
|
"task_types": types,
|
|
"new_schema": schema
|
|
}
|
|
|
|
def _process_json_structures(self, schema, schemas, labels, types, sampling):
|
|
"""Process JSON structure schemas."""
|
|
if "json_structures" not in schema:
|
|
return
|
|
|
|
json_descs = schema.get("json_descriptions", {})
|
|
groups = {}
|
|
|
|
for item in schema["json_structures"]:
|
|
for parent, fields in item.items():
|
|
groups.setdefault(parent, []).append(fields)
|
|
|
|
for parent, occurrences in groups.items():
|
|
if sampling and random.random() < sampling.remove_json_structure_prob:
|
|
continue
|
|
|
|
all_fields = set()
|
|
for occ in occurrences:
|
|
all_fields.update(occ.keys())
|
|
common = list(all_fields)
|
|
|
|
if sampling and sampling.shuffle_json_fields:
|
|
random.shuffle(common)
|
|
|
|
chosen = [f for f in common if not (
|
|
sampling and random.random() < sampling.remove_json_field_prob
|
|
)]
|
|
if not chosen:
|
|
continue
|
|
|
|
# Handle synthetic labeling
|
|
real2syn = {}
|
|
descs = json_descs.get(parent, {})
|
|
example_modes = ["none", "descriptions"]
|
|
|
|
if sampling and random.random() < sampling.synthetic_entity_label_prob:
|
|
example_modes.remove("none")
|
|
synthetic = []
|
|
for i, real in enumerate(chosen, 1):
|
|
syn = f"field {i}"
|
|
real2syn[real] = syn
|
|
synthetic.append(syn)
|
|
descs = {real2syn.get(k, k): descs.get(k, k) for k in chosen}
|
|
chosen = synthetic
|
|
|
|
# Build spans
|
|
spans = []
|
|
for occ in occurrences:
|
|
span = [occ.get(f) for f in chosen]
|
|
spans.append(span)
|
|
|
|
# Dedup
|
|
uniq = []
|
|
seen = set()
|
|
for s in spans:
|
|
key = tuple(tuple(x) if isinstance(x, list) else x for x in s)
|
|
if key not in seen:
|
|
uniq.append(s)
|
|
seen.add(key)
|
|
|
|
# Check for empty
|
|
if all(all(c is None or c == "" for c in span) for span in uniq):
|
|
count = 0
|
|
uniq = []
|
|
else:
|
|
count = len(uniq)
|
|
|
|
labels.append([count, uniq])
|
|
|
|
mode = random.choice(example_modes) if self.is_training else (
|
|
"descriptions" if descs else "none"
|
|
)
|
|
|
|
schemas.append(self._transform_schema(
|
|
parent, chosen, self.C_TOKEN, label_descriptions=descs, example_mode=mode
|
|
))
|
|
types.append("json_structures")
|
|
|
|
    def _process_entities(self, schema, schemas, labels, types, sampling):
        """Process entity schemas.

        Emits at most one schema/label/type triple covering all entity
        labels at once: the label is [1, [span]], where span holds the gold
        value for each chosen entity label (a single "occurrence").

        Training-time sampling may drop the whole task, drop or shuffle
        individual labels, or replace label names with synthetic "entity N"
        placeholders (the schema's entity dict is renamed in lockstep so
        gold values stay reachable).
        """
        if "entities" not in schema:
            return

        if sampling and random.random() < sampling.remove_entities_prob:
            return

        entity_fields = list(schema["entities"].keys())
        descs = schema.get("entity_descriptions", {})
        example_modes = ["none", "descriptions"]

        real2syn = {}
        if sampling and random.random() < sampling.synthetic_entity_label_prob:
            # Synthetic names force the model to rely on descriptions.
            example_modes.remove("none")
            synthetic = []
            for i, real in enumerate(entity_fields, 1):
                syn = f"entity {i}"
                real2syn[real] = syn
                synthetic.append(syn)
            # Rename descriptions AND the schema's entity dict so lookups by
            # synthetic name still reach the gold values.
            descs = {real2syn.get(k, k): v for k, v in descs.items()}
            schema["entities"] = {real2syn.get(k, k): v for k, v in schema["entities"].items()}
            entity_fields = synthetic

        if sampling and sampling.shuffle_entities:
            random.shuffle(entity_fields)

        chosen = [e for e in entity_fields if not (
            sampling and random.random() < sampling.remove_entity_prob
        )]

        if chosen:
            span = [schema["entities"][e] for e in chosen]
            labels.append([1, [span]])

            # Inference: use descriptions when available, otherwise none.
            mode = random.choice(example_modes) if self.is_training else (
                "descriptions" if descs else "none"
            )

            schemas.append(self._transform_schema(
                "entities", chosen, self.E_TOKEN, label_descriptions=descs, example_mode=mode
            ))
            types.append("entities")
|
|
|
|
def _process_relations(self, schema, schemas, labels, types, sampling):
|
|
"""Process relation schemas."""
|
|
if "relations" not in schema:
|
|
return
|
|
|
|
groups = {}
|
|
for item in schema["relations"]:
|
|
if sampling and random.random() < sampling.remove_relations_prob:
|
|
continue
|
|
for parent, fields in item.items():
|
|
groups.setdefault(parent, []).append(fields)
|
|
|
|
for parent, occurrences in groups.items():
|
|
field_names = list(occurrences[0].keys())
|
|
|
|
if sampling and "head" in field_names and "tail" in field_names:
|
|
if random.random() < sampling.swap_head_tail_prob:
|
|
idx_h = field_names.index("head")
|
|
idx_t = field_names.index("tail")
|
|
field_names[idx_h], field_names[idx_t] = field_names[idx_t], field_names[idx_h]
|
|
|
|
spans = []
|
|
for occ in occurrences:
|
|
if all(f in occ for f in field_names):
|
|
spans.append([occ[f] for f in field_names])
|
|
|
|
if not spans:
|
|
continue
|
|
|
|
# Dedup
|
|
seen = set()
|
|
uniq = []
|
|
for span in spans:
|
|
t = tuple(tuple(s) if isinstance(s, list) else s for s in span)
|
|
if t not in seen:
|
|
seen.add(t)
|
|
uniq.append(span)
|
|
|
|
labels.append([len(uniq), uniq])
|
|
schemas.append(self._transform_schema(parent, field_names, self.R_TOKEN))
|
|
types.append("relations")
|
|
|
|
def _process_classifications(self, schema, schemas, labels, types, sampling):
|
|
"""Process classification schemas."""
|
|
if "classifications" not in schema:
|
|
return
|
|
|
|
for idx, item in enumerate(schema["classifications"]):
|
|
if sampling and random.random() < sampling.remove_classification_prob:
|
|
continue
|
|
|
|
cls_labels = item["labels"].copy()
|
|
examples = item.get("examples", [])
|
|
descs = item.get("label_descriptions", {}) or {}
|
|
|
|
real2syn = {}
|
|
example_modes = ["few_shot", "descriptions", "both", "none"] if self.is_training else ["both"]
|
|
|
|
if sampling and random.random() < sampling.synthetic_label_prob:
|
|
example_modes = [m for m in example_modes if m != "none"]
|
|
synthetic = []
|
|
for i, real in enumerate(cls_labels, 1):
|
|
syn = f"label {i}"
|
|
real2syn[real] = syn
|
|
synthetic.append(syn)
|
|
cls_labels = synthetic
|
|
descs = {real2syn.get(k, k): descs.get(k, k) for k in item["labels"]}
|
|
examples = [(inp, real2syn.get(out, out)) for inp, out in examples]
|
|
|
|
mode = random.choice(example_modes) if example_modes else "none"
|
|
|
|
# Label dropping
|
|
if sampling and hasattr(sampling, "remove_classification_label_prob"):
|
|
drop_frac = random.betavariate(1, 1) * sampling.remove_classification_label_prob
|
|
num_remove = int(len(cls_labels) * drop_frac)
|
|
if num_remove > 0:
|
|
cls_labels = random.sample(cls_labels, len(cls_labels) - num_remove)
|
|
|
|
max_labels = sampling.max_num_labels // 2 if mode in ["few_shot", "both",
|
|
"descriptions"] else sampling.max_num_labels
|
|
if len(cls_labels) > max_labels:
|
|
cls_labels = cls_labels[:max_labels]
|
|
|
|
if random.random() < sampling.include_true_label_prob:
|
|
true_label = item.get("true_label", [])
|
|
if isinstance(true_label, list):
|
|
for tl in true_label:
|
|
if tl not in cls_labels:
|
|
cls_labels.append(tl)
|
|
elif true_label not in cls_labels:
|
|
cls_labels.append(true_label)
|
|
|
|
if sampling and sampling.shuffle_classification_labels:
|
|
random.shuffle(cls_labels)
|
|
|
|
schemas.append(self._transform_schema(
|
|
item["task"], cls_labels, self.L_TOKEN,
|
|
prompt=item.get("prompt"), examples=examples,
|
|
label_descriptions=descs, example_mode=mode
|
|
))
|
|
types.append("classifications")
|
|
|
|
# Update schema
|
|
schema["classifications"][idx]["labels"] = cls_labels
|
|
true_label = schema["classifications"][idx]["true_label"].copy()
|
|
schema["classifications"][idx]["true_label"] = [real2syn.get(i, i) for i in true_label]
|
|
labels.append([])
|
|
|
|
def _transform_schema(
|
|
self,
|
|
parent: str,
|
|
fields: List[str],
|
|
child_prefix: str,
|
|
prompt: str = None,
|
|
examples: List[Tuple[str, str]] = None,
|
|
label_descriptions: Dict[str, str] = None,
|
|
example_mode: str = "both"
|
|
) -> List[str]:
|
|
"""Transform schema into token sequence."""
|
|
prompt_str = parent
|
|
if prompt:
|
|
prompt_str = f"{parent}: {prompt}"
|
|
|
|
if example_mode in ["descriptions", "both"] and label_descriptions:
|
|
descs = [(l, d) for l, d in label_descriptions.items() if l in fields]
|
|
if self.is_training:
|
|
random.shuffle(descs)
|
|
for label, desc in descs:
|
|
prompt_str += f" {self.DESC_TOKEN} {label}: {desc}"
|
|
|
|
if example_mode in ["few_shot", "both"] and examples:
|
|
if self.is_training:
|
|
random.shuffle(examples)
|
|
for inp, out in examples:
|
|
if out in fields:
|
|
out_str = out if isinstance(out, str) else ', '.join(out)
|
|
prompt_str += f" {self.EXAMPLE_TOKEN} {inp} {self.OUTPUT_TOKEN} {out_str}"
|
|
|
|
tokens = ["(", self.P_TOKEN, prompt_str, "("]
|
|
for field in fields:
|
|
tokens.extend([child_prefix, field])
|
|
tokens.extend([")", ")"])
|
|
|
|
return tokens
|
|
|
|
    def _build_outputs(
        self,
        processed: Dict,
        schema: Dict,
        text_tokens: List[str],
        len_prefix: int
    ) -> List[Dict]:
        """Build output labels for each schema.

        For span tasks (entities/relations/json_structures) every gold value
        is located in `text_tokens` and converted to inclusive (start, end)
        index pairs; for classifications a 0/1 vector over the candidate
        labels is produced instead.

        Args:
            processed: Output of _infer_from_json (aligned "schemas",
                "task_types", "structure_labels" lists).
            schema: The (possibly mutated) schema; only its
                "classifications" entries are consulted here.
            text_tokens: Word tokens, classification prefix included.
            len_prefix: Number of prefix tokens at the head of text_tokens.

        Returns:
            One {"task_type", "schema_tokens", "output"} dict per schema.

        Raises:
            ValueError: If a classification schema cannot be matched back to
                its classification item.
        """
        results = []

        for schema_tokens, task_type, struct_label in zip(
            processed["schemas"],
            processed["task_types"],
            processed["structure_labels"]
        ):
            if task_type != "classifications":
                count, spans = struct_label
                transformed = []

                for span in spans:
                    positions = []
                    for field in span:
                        if isinstance(field, list):
                            # Multi-valued field: gather matches for each value.
                            nested = []
                            for sub in field:
                                if str(sub).startswith("[selection]"):
                                    # Use case-insensitive matching for choice fields;
                                    # choices only occur inside the classification
                                    # prefix. 11 == len("[selection]").
                                    pos = self._find_sublist(
                                        [str(sub)[11:]], text_tokens[:len_prefix],
                                        case_insensitive=True
                                    )
                                else:
                                    pos = self._find_sublist(
                                        self._tokenize_text(str(sub)), text_tokens
                                    )
                                nested.extend(pos)
                            positions.append(nested)
                        else:
                            if str(field).startswith("[selection]"):
                                # Use case-insensitive matching for choice fields
                                pos = self._find_sublist(
                                    [str(field)[11:]], text_tokens[:len_prefix],
                                    case_insensitive=True
                                )
                            else:
                                pos = self._find_sublist(
                                    self._tokenize_text(str(field)), text_tokens
                                )
                            positions.append(pos)
                    transformed.append(positions)

                results.append({
                    "task_type": task_type,
                    "schema_tokens": schema_tokens,
                    "output": [count, transformed]
                })
            else:
                # schema_tokens[2] is the prompt string, which begins with
                # the task name; prefix-match it back to its classification.
                cls_item = next(
                    (c for c in schema["classifications"] if schema_tokens[2].startswith(c["task"])),
                    None
                )
                if cls_item is None:
                    raise ValueError(f"Missing classification for: {schema_tokens[2]}")

                # 0/1 vector: which candidate labels are gold labels.
                bool_labels = [1 if l in cls_item["true_label"] else 0 for l in cls_item["labels"]]
                results.append({
                    "task_type": task_type,
                    "schema_tokens": schema_tokens,
                    "output": bool_labels
                })

        return results
|
|
|
|
def _find_sublist(
|
|
self,
|
|
sub: List[str],
|
|
lst: List[str],
|
|
case_insensitive: bool = False
|
|
) -> List[Tuple[int, int]]:
|
|
"""Find all occurrences of sublist in list.
|
|
|
|
Args:
|
|
sub: Sublist to search for
|
|
lst: List to search in
|
|
case_insensitive: If True, use case-insensitive matching
|
|
"""
|
|
if not sub or all(t == "" for t in sub):
|
|
return [(-1, -1)]
|
|
|
|
sub_len = len(sub)
|
|
|
|
if case_insensitive:
|
|
sub_lower = [s.lower() for s in sub]
|
|
matches = [
|
|
(i, i + sub_len - 1)
|
|
for i in range(len(lst) - sub_len + 1)
|
|
if [t.lower() for t in lst[i:i + sub_len]] == sub_lower
|
|
]
|
|
else:
|
|
matches = [
|
|
(i, i + sub_len - 1)
|
|
for i in range(len(lst) - sub_len + 1)
|
|
if lst[i:i + sub_len] == sub
|
|
]
|
|
return matches or [(-1, -1)]
|
|
|
|
def _tokenize_text(self, text: str) -> List[str]:
|
|
"""Tokenize text into words."""
|
|
return [tok for tok, _, _ in self.word_splitter(text, lower=True)]
|
|
|
|
# =========================================================================
|
|
# Internal: Input Formatting
|
|
# =========================================================================
|
|
|
|
    def _format_input_with_mapping(
        self,
        schema_tokens_list: List[List[str]],
        text_tokens: List[str]
    ) -> Dict[str, Any]:
        """Format input and create token mappings.

        Concatenates all schema token sequences (separated by [SEP_STRUCT]),
        then [SEP_TEXT], then the text tokens; subword-tokenizes the result
        and records for every subword which segment ("schema" | "sep" |
        "text"), original word-token index, and schema index it came from.

        Returns:
            {"input_ids": List[int],
             "mapped_indices": List[Tuple[str, int, int]],
             "subword_list": List[str]}
        """
        # Build combined tokens
        combined = []
        for struct in schema_tokens_list:
            combined.extend(struct)
            combined.append(self.SEP_STRUCT)
        if combined:
            # Drop the trailing [SEP_STRUCT] after the last schema.
            combined.pop()
        combined.append(self.SEP_TEXT)
        combined.extend(text_tokens)

        # Build subword list and mappings
        subwords = []
        mappings = []

        num_schemas = len(schema_tokens_list)
        # Text (and the separator) map to the index one past the last schema.
        text_schema_idx = num_schemas
        current_schema = 0
        found_sep = False

        for orig_idx, token in enumerate(combined):
            if token == self.SEP_TEXT:
                seg_type = "sep"
                schema_idx = text_schema_idx
                found_sep = True
            elif not found_sep:
                seg_type = "schema"
                schema_idx = current_schema
                # [SEP_STRUCT] itself still maps to the schema it closes.
                if token == self.SEP_STRUCT:
                    current_schema += 1
            else:
                seg_type = "text"
                schema_idx = text_schema_idx

            # OPT-6: Use cached tokenizations for special tokens and punctuation
            if token in self._token_cache:
                sub_tokens = self._token_cache[token]
            else:
                sub_tokens = self.tokenizer.tokenize(token)
            subwords.extend(sub_tokens)
            # One mapping entry per emitted subword.
            mappings.extend([(seg_type, orig_idx, schema_idx)] * len(sub_tokens))

        input_ids = self.tokenizer.convert_tokens_to_ids(subwords)

        return {
            "input_ids": input_ids,
            "mapped_indices": mappings,
            "subword_list": subwords
        }
|
|
|
|
# =========================================================================
|
|
# Embedding Extraction (Called by Model)
|
|
# =========================================================================
|
|
|
|
    def extract_embeddings_from_batch(
        self,
        token_embeddings: torch.Tensor,
        input_ids: torch.Tensor,
        batch: PreprocessedBatch
    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        """
        Extract token and schema embeddings from encoded batch.

        Text subwords belonging to the same original word are pooled with
        self._aggregate; from the schema section only the positions of the
        [P]/[C]/[E]/[R]/[L] marker tokens are kept (per schema).

        Args:
            token_embeddings: (batch, seq_len, hidden) from encoder
            input_ids: (batch, seq_len) input token IDs
            batch: PreprocessedBatch with metadata

        Returns:
            - all_token_embs: List of (text_len, hidden) per sample
            - all_schema_embs: List of schema embeddings per sample
        """
        all_token_embs = []
        all_schema_embs = []

        # OPT-1: Use pre-computed special token IDs instead of string comparison
        special_ids = self._special_ids

        for i in range(len(batch)):
            # Trim padding using the recorded original length.
            seq_len = batch.original_lengths[i]
            embs = token_embeddings[i, :seq_len, :]
            ids = input_ids[i, :seq_len].tolist()
            mappings = batch.mapped_indices[i][:seq_len]
            num_schemas = batch.schema_counts[i]

            schema_embs = [[] for _ in range(num_schemas)]
            word_embs = []
            bucket = []  # subword embeddings of the word currently being built
            last_orig = None  # original token index of the previous text subword

            for j, tid in enumerate(ids):
                seg_type, orig_idx, schema_idx = mappings[j]
                emb = embs[j]

                if seg_type == "schema":
                    # Keep only marker-token positions as schema embeddings.
                    if tid in special_ids:
                        schema_embs[schema_idx].append(emb)

                elif seg_type == "text":
                    # A new original token index closes the previous word's
                    # subword bucket.
                    if last_orig is not None and orig_idx != last_orig and bucket:
                        word_embs.append(self._aggregate(bucket))
                        bucket = []
                    bucket.append(emb)
                    last_orig = orig_idx

            # Flush the final word.
            if bucket:
                word_embs.append(self._aggregate(bucket))

            all_token_embs.append(
                torch.stack(word_embs) if word_embs else torch.empty(0, embs.shape[-1], device=embs.device)
            )
            all_schema_embs.append(schema_embs)

        return all_token_embs, all_schema_embs
|
|
|
|
def _aggregate(self, pieces: List[torch.Tensor]) -> torch.Tensor:
|
|
"""Aggregate subword embeddings."""
|
|
# OPT-10: Short-circuit for single subword tokens (common case)
|
|
if len(pieces) == 1:
|
|
return pieces[0]
|
|
if self.token_pooling == "first":
|
|
return pieces[0]
|
|
stack = torch.stack(pieces)
|
|
if self.token_pooling == "mean":
|
|
return stack.mean(dim=0)
|
|
if self.token_pooling == "max":
|
|
return stack.max(dim=0).values
|
|
return pieces[0] |