agent-smith/packages/GLiNER2/gliner2/inference/engine.py
2026-03-06 12:59:32 +01:00

1458 lines
57 KiB
Python

"""
GLiNER2 - Advanced Information Extraction Engine
This module provides the main GLiNER2 class with optimized batch processing
using DataLoader-based parallel preprocessing.
Example:
>>> from gliner2 import GLiNER2
>>>
>>> extractor = GLiNER2.from_pretrained("model-repo")
>>>
>>> # Simple extraction
>>> results = extractor.extract_entities(
... "Apple released iPhone 15.",
... ["company", "product"]
... )
>>>
>>> # Batch extraction (parallel preprocessing)
>>> results = extractor.batch_extract_entities(
... texts_list,
... ["company", "product"],
... batch_size=32,
... num_workers=4
... )
"""
from __future__ import annotations
import re
import hashlib
import json
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING, Pattern, Literal
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from gliner2.model import Extractor
from gliner2.processor import PreprocessedBatch
from gliner2.training.trainer import ExtractorCollator
if TYPE_CHECKING:
from gliner2.api_client import GLiNER2API
# =============================================================================
# Validators
# =============================================================================
@dataclass
class RegexValidator:
    """Post-processing predicate that accepts or rejects a span's text by regex.

    Attributes:
        pattern: Regex source string, or an already-compiled pattern. When a
            compiled pattern is supplied, ``flags`` is not applied to it.
        mode: ``"full"`` requires the entire text to match (``fullmatch``);
            ``"partial"`` accepts a match anywhere in the text (``search``).
        exclude: Invert the verdict, i.e. reject texts that match.
        flags: ``re`` flags used when compiling a string pattern.
    """
    pattern: str | Pattern[str]
    mode: Literal["full", "partial"] = "full"
    exclude: bool = False
    flags: int = re.IGNORECASE
    _compiled: Pattern[str] = field(init=False, repr=False)

    def __post_init__(self):
        """Validate ``mode`` and compile the pattern once.

        Raises:
            ValueError: If ``mode`` is not ``"full"``/``"partial"`` or the
                pattern string does not compile.
        """
        if self.mode not in {"full", "partial"}:
            raise ValueError(f"mode must be 'full' or 'partial', got {self.mode!r}")
        if isinstance(self.pattern, re.Pattern):
            regex = self.pattern
        else:
            try:
                regex = re.compile(self.pattern, self.flags)
            except re.error as err:
                raise ValueError(f"Invalid regex: {self.pattern!r}") from err
        object.__setattr__(self, "_compiled", regex)

    def __call__(self, text: str) -> bool:
        """Alias for :meth:`validate` so an instance can be used as a predicate."""
        return self.validate(text)

    def validate(self, text: str) -> bool:
        """Return True when *text* passes this validator."""
        if self.mode == "full":
            hit = self._compiled.fullmatch(text)
        else:
            hit = self._compiled.search(text)
        return (hit is None) if self.exclude else (hit is not None)
# =============================================================================
# Schema Builder
# =============================================================================
class StructureBuilder:
    """Fluent builder for one structured-extraction entry of a Schema.

    Created by ``Schema.structure``. Fields are collected locally and written
    into the owning schema's ``json_structures`` (plus ``json_descriptions``
    when any field has a description) the first time the builder is finished.
    Accessing any attribute that exists on the owning schema (e.g.
    ``.entities`` or ``.build``) finishes this builder and forwards to the
    schema, which is what allows chaining across builders.
    """

    def __init__(self, schema: 'Schema', parent: str):
        self.schema = schema
        self.parent = parent
        self.fields = OrderedDict()
        self.descriptions = OrderedDict()
        self.field_order = []
        self._finished = False

    def field(
        self,
        name: str,
        dtype: Literal["str", "list"] = "list",
        choices: Optional[List[str]] = None,
        description: Optional[str] = None,
        threshold: Optional[float] = None,
        validators: Optional[List[RegexValidator]] = None
    ) -> 'StructureBuilder':
        """Add a field to the structure.

        Args:
            name: Field name.
            dtype: ``"list"`` for multiple values, ``"str"`` for a single value.
            choices: If given, the field is stored as ``{"value": "", "choices": ...}``
                and is decoded as a classification over these options.
            description: Optional description; only non-empty ones are stored.
            threshold: Optional per-field confidence threshold in [0, 1].
            validators: Optional regex validators for extracted spans.

        Returns:
            This builder, for chaining.

        Raises:
            ValueError: If ``threshold`` is outside [0, 1]; the field is then
                not added at all.
        """
        # Store (and thereby validate) the metadata first so an invalid
        # threshold cannot leave a half-registered field behind.
        self.schema._store_field_metadata(self.parent, name, dtype, threshold, choices, validators)
        self.fields[name] = {"value": "", "choices": choices} if choices else ""
        self.field_order.append(name)
        if description:
            self.descriptions[name] = description
        return self

    def _auto_finish(self):
        """Write this structure into the owning schema exactly once."""
        if not self._finished:
            self.schema._store_field_order(self.parent, self.field_order)
            self.schema.schema["json_structures"].append({self.parent: self.fields})
            if self.descriptions:
                if "json_descriptions" not in self.schema.schema:
                    self.schema.schema["json_descriptions"] = {}
                self.schema.schema["json_descriptions"][self.parent] = self.descriptions
            self._finished = True

    def __getattr__(self, name):
        """Forward unknown attributes to the owning schema, finishing first.

        Raises:
            AttributeError: If neither this builder nor its schema has *name*.
        """
        # Read ``schema`` straight from the instance dict: on an instance that
        # has no ``schema`` yet (e.g. created via ``__new__`` by copy/pickle),
        # ``self.schema`` would re-enter __getattr__ and recurse forever.
        schema = self.__dict__.get("schema")
        if schema is not None and hasattr(schema, name):
            self._auto_finish()
            return getattr(schema, name)
        raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
class Schema:
    """Fluent builder for extraction schemas.

    Entity, relation, classification and structure tasks are accumulated in
    the dict returned by :meth:`build`. Per-item metadata (dtype, thresholds,
    choices, validators, ordering) is kept in private attributes that
    ``GLiNER2.batch_extract`` reads alongside the built dict.
    """

    def __init__(self):
        self.schema = {
            "json_structures": [],
            "classifications": [],
            "entities": OrderedDict(),
            "relations": [],
            "json_descriptions": {},
            "entity_descriptions": OrderedDict()
        }
        # "<structure>.<field>" -> {"dtype", "threshold", "choices", "validators"}
        self._field_metadata = {}
        # entity name -> {"dtype", "threshold"}
        self._entity_metadata = {}
        # relation name -> {"threshold"}
        self._relation_metadata = {}
        # structure / relation name -> ordered field names
        self._field_orders = {}
        self._entity_order = []
        self._relation_order = []
        # StructureBuilder still being filled; flushed before other tasks are added.
        self._active_builder = None

    @staticmethod
    def _check_threshold(threshold):
        """Raise ``ValueError`` unless *threshold* is ``None`` or within [0, 1]."""
        if threshold is not None and not 0 <= threshold <= 1:
            raise ValueError(f"Threshold must be 0-1, got {threshold}")

    def _finish_active_builder(self):
        """Flush the pending StructureBuilder, if any, into ``self.schema``."""
        if self._active_builder:
            self._active_builder._auto_finish()
            self._active_builder = None

    def _store_field_metadata(self, parent, field, dtype, threshold, choices, validators=None):
        """Record metadata for structure field ``parent.field`` (validates threshold)."""
        self._check_threshold(threshold)
        self._field_metadata[f"{parent}.{field}"] = {
            "dtype": dtype, "threshold": threshold, "choices": choices,
            "validators": validators or []
        }

    def _store_entity_metadata(self, entity, dtype, threshold):
        """Record dtype/threshold for *entity* (validates threshold)."""
        self._check_threshold(threshold)
        self._entity_metadata[entity] = {"dtype": dtype, "threshold": threshold}

    def _store_field_order(self, parent, order):
        """Remember the declared field order of structure *parent*."""
        self._field_orders[parent] = order

    def structure(self, name: str) -> 'StructureBuilder':
        """Start building a structure schema (finishes any pending builder)."""
        self._finish_active_builder()
        self._active_builder = StructureBuilder(self, name)
        return self._active_builder

    def classification(
        self,
        task: str,
        labels: Union[List[str], Dict[str, str]],
        multi_label: bool = False,
        cls_threshold: float = 0.5,
        **kwargs
    ) -> 'Schema':
        """Add a classification task.

        Args:
            task: Task name.
            labels: Label names, or a mapping label -> description.
            multi_label: Allow several labels per text.
            cls_threshold: Probability threshold used for multi-label decoding.
            **kwargs: Extra config keys stored verbatim (they may override
                the defaults above, including ``true_label``).

        Returns:
            This schema, for chaining.
        """
        self._finish_active_builder()
        label_names = list(labels.keys()) if isinstance(labels, dict) else labels
        label_descs = labels if isinstance(labels, dict) else None
        config = {
            "task": task, "labels": label_names,
            "multi_label": multi_label, "cls_threshold": cls_threshold,
            "true_label": ["N/A"], **kwargs
        }
        if label_descs:
            config["label_descriptions"] = label_descs
        self.schema["classifications"].append(config)
        return self

    def entities(
        self,
        entity_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
        dtype: Literal["str", "list"] = "list",
        threshold: Optional[float] = None
    ) -> 'Schema':
        """Add an entity extraction task.

        Args:
            entity_types: A name, a list of names, or a mapping name -> description
                string or name -> config dict (``description``, ``dtype``,
                ``threshold``).
            dtype: Default dtype for entities without their own.
            threshold: Default threshold for entities without their own.

        Returns:
            This schema, for chaining.

        Raises:
            ValueError: On an invalid ``entity_types`` format or any threshold
                outside [0, 1]; nothing is registered in that case.
        """
        self._finish_active_builder()
        entities = self._parse_entity_input(entity_types)
        # Validate every threshold up front so a bad value cannot leave the
        # schema partially updated.
        for config in entities.values():
            self._check_threshold(config.get("threshold", threshold))
        for name, config in entities.items():
            self.schema["entities"][name] = ""
            if name not in self._entity_order:
                self._entity_order.append(name)
            self._store_entity_metadata(
                name,
                config.get("dtype", dtype),
                config.get("threshold", threshold)
            )
            # Only non-empty descriptions are stored, matching StructureBuilder.field.
            description = config.get("description")
            if description:
                self.schema["entity_descriptions"][name] = description
        return self

    def _parse_entity_input(self, entity_types):
        """Normalise ``entity_types`` to ``{name: config_dict}``.

        Raises:
            ValueError: If ``entity_types`` is not a str, list or dict.
        """
        if isinstance(entity_types, str):
            return {entity_types: {}}
        elif isinstance(entity_types, list):
            return {name: {} for name in entity_types}
        elif isinstance(entity_types, dict):
            result = {}
            for name, config in entity_types.items():
                if isinstance(config, str):
                    result[name] = {"description": config}
                elif isinstance(config, dict):
                    result[name] = config
                else:
                    result[name] = {}
            return result
        raise ValueError("Invalid entity_types format")

    def relations(
        self,
        relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
        threshold: Optional[float] = None
    ) -> 'Schema':
        """Add a relation extraction task (each relation has head/tail fields).

        Args:
            relation_types: A name, a list of names, or a mapping name ->
                description string / config dict (``threshold`` is honoured).
            threshold: Default threshold for relations without their own.

        Returns:
            This schema, for chaining.

        Raises:
            ValueError: On an invalid ``relation_types`` format or any threshold
                outside [0, 1]; nothing is registered in that case.
        """
        self._finish_active_builder()
        if isinstance(relation_types, str):
            relations = {relation_types: {}}
        elif isinstance(relation_types, list):
            relations = {name: {} for name in relation_types}
        elif isinstance(relation_types, dict):
            relations = {}
            for name, config in relation_types.items():
                relations[name] = {"description": config} if isinstance(config, str) else (config if isinstance(config, dict) else {})
        else:
            raise ValueError("Invalid relation_types format")
        # Validate before mutating anything.
        for config in relations.values():
            self._check_threshold(config.get("threshold", threshold))
        # NOTE(review): relation descriptions are parsed above but not stored
        # anywhere in the schema; confirm whether that is intended.
        for name, config in relations.items():
            self.schema["relations"].append({name: {"head": "", "tail": ""}})
            if name not in self._relation_order:
                self._relation_order.append(name)
            self._field_orders[name] = ["head", "tail"]
            self._relation_metadata[name] = {"threshold": config.get("threshold", threshold)}
        return self

    def build(self) -> Dict[str, Any]:
        """Finish any pending structure and return the schema dict."""
        self._finish_active_builder()
        return self.schema

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Schema':
        """Create a Schema from a dictionary.

        Args:
            data: Dictionary with optional keys: entities, structures,
                classifications, relations

        Returns:
            Schema: Constructed schema instance

        Raises:
            ValidationError: If the input data is invalid

        Example:
            >>> schema_dict = {
            ...     "entities": ["company", "person"],
            ...     "structures": {
            ...         "product_info": {
            ...             "fields": [
            ...                 {"name": "company", "dtype": "str"},
            ...                 {"name": "product"}
            ...             ]
            ...         }
            ...     },
            ...     "classifications": [
            ...         {"task": "sentiment", "labels": ["positive", "negative"]}
            ...     ],
            ...     "relations": ["works_for", "founded_by"]
            ... }
            >>> schema = Schema.from_dict(schema_dict)
        """
        from gliner2.inference.schema_model import SchemaInput
        # Validate input
        validated = SchemaInput(**data)
        # Build schema using builder API
        schema = cls()
        if validated.entities is not None:
            schema.entities(validated.entities)
        if validated.structures is not None:
            for struct_name, struct_input in validated.structures.items():
                builder = schema.structure(struct_name)
                for field_input in struct_input.fields:
                    builder.field(
                        name=field_input.name,
                        dtype=field_input.dtype,
                        choices=field_input.choices,
                        description=field_input.description
                    )
                # Auto-finish the builder
                builder._auto_finish()
        if validated.classifications is not None:
            for cls_input in validated.classifications:
                schema.classification(
                    task=cls_input.task,
                    labels=cls_input.labels,
                    multi_label=cls_input.multi_label
                )
        if validated.relations is not None:
            schema.relations(validated.relations)
        return schema

    @classmethod
    def from_json(cls, json_str: str) -> 'Schema':
        """Create a Schema from a JSON string.

        Args:
            json_str: JSON string with schema definition

        Returns:
            Schema: Constructed schema instance

        Raises:
            ValidationError: If the input data is invalid
            json.JSONDecodeError: If the JSON is malformed

        Example:
            >>> schema_json = '''
            ... {
            ...     "entities": ["company", "person"],
            ...     "classifications": [
            ...         {"task": "sentiment", "labels": ["positive", "negative"]}
            ...     ]
            ... }
            ... '''
            >>> schema = Schema.from_json(schema_json)
        """
        data = json.loads(json_str)
        return cls.from_dict(data)

    def to_dict(self) -> Dict[str, Any]:
        """Convert schema to user-friendly dictionary format.

        Exported: entity names/descriptions, structure fields (name, non-default
        dtype, choices, description), classification task/labels/multi_label,
        and relation names. Thresholds, validators and label descriptions are
        not exported. Returned lists are copies, so mutating the result does
        not affect this schema.

        Returns:
            Dict: Schema in dictionary format compatible with from_dict()

        Example:
            >>> schema = Schema()
            >>> schema.entities(["company", "person"])
            >>> schema_dict = schema.to_dict()
            >>> # schema_dict can be used with Schema.from_dict()
        """
        result = {}
        # Export entities
        entity_names = list(self.schema["entities"].keys())
        if entity_names:
            descriptions = self.schema["entity_descriptions"]
            if descriptions:
                # Include every entity, not only the described ones; "" marks an
                # entity without a description (entities() does not store it).
                result["entities"] = {name: descriptions.get(name, "") for name in entity_names}
            else:
                result["entities"] = entity_names
        # Export structures
        if self.schema["json_structures"]:
            result["structures"] = {}
            for struct_dict in self.schema["json_structures"]:
                for struct_name, struct_fields in struct_dict.items():
                    fields = []
                    for field_name in self._field_orders.get(struct_name, []):
                        if field_name not in struct_fields:
                            continue
                        metadata = self._field_metadata.get(f"{struct_name}.{field_name}", {})
                        field_def = {"name": field_name}
                        dtype = metadata.get("dtype", "list")
                        if dtype != "list":
                            field_def["dtype"] = dtype
                        choices = metadata.get("choices")
                        if choices:
                            field_def["choices"] = list(choices)
                        desc = self.schema.get("json_descriptions", {}).get(struct_name, {}).get(field_name)
                        if desc:
                            field_def["description"] = desc
                        fields.append(field_def)
                    result["structures"][struct_name] = {"fields": fields}
        # Export classifications
        if self.schema["classifications"]:
            result["classifications"] = []
            for cls_config in self.schema["classifications"]:
                cls_def = {
                    "task": cls_config["task"],
                    "labels": cls_config["labels"]
                }
                if cls_config.get("multi_label", False):
                    cls_def["multi_label"] = True
                result["classifications"].append(cls_def)
        # Export relations
        if self.schema["relations"]:
            result["relations"] = list(self._relation_order) if self._relation_order else [
                list(rel_dict.keys())[0] for rel_dict in self.schema["relations"]
            ]
        return result
# =============================================================================
# Main GLiNER2 Class
# =============================================================================
class GLiNER2(Extractor):
"""
GLiNER2 Information Extraction Model.
Provides efficient batch extraction with parallel preprocessing.
"""
def __init__(self, *args, **kwargs):
    """Initialise the underlying Extractor, then the inference-time caches.

    All positional and keyword arguments are forwarded unchanged to the
    parent ``Extractor`` constructor.
    """
    super().__init__(*args, **kwargs)
    # Both caches start empty; the collator (OPT-11) is created lazily by
    # batch_extract on first use and then reused.
    self._schema_cache, self._inference_collator = {}, None
@classmethod
def from_api(cls, api_key: Optional[str] = None, api_base_url: Optional[str] = None,
             timeout: float = 30.0, max_retries: int = 3) -> 'GLiNER2API':
    """Return a remote ``GLiNER2API`` client instead of loading a local model.

    Args:
        api_key: API key passed to the client.
        api_base_url: Base URL passed to the client.
        timeout: Request timeout passed to the client.
        max_retries: Retry count passed to the client.

    Returns:
        A ``GLiNER2API`` instance built from the arguments above.
    """
    # Imported lazily so the API client is only loaded when requested.
    from gliner2.api_client import GLiNER2API
    client_options = {
        "api_key": api_key,
        "api_base_url": api_base_url,
        "timeout": timeout,
        "max_retries": max_retries,
    }
    return GLiNER2API(**client_options)
def create_schema(self) -> Schema:
    """Return a fresh, empty ``Schema`` builder (independent of this model)."""
    fresh_schema = Schema()
    return fresh_schema
# =========================================================================
# Main Batch Extraction
# =========================================================================
@torch.inference_mode()
def batch_extract(
    self,
    texts: List[str],
    schemas: Union[Schema, List[Schema], Dict, List[Dict]],
    batch_size: int = 8,
    threshold: float = 0.5,
    num_workers: int = 0,
    format_results: bool = True,
    include_confidence: bool = False,
    include_spans: bool = False
) -> List[Dict[str, Any]]:
    """
    Extract from multiple texts with parallel preprocessing.

    Args:
        texts: List of input texts
        schemas: A single schema (``Schema`` or dict) shared by all texts, or a
            list/tuple containing exactly one schema per text
        batch_size: Batch size for processing
        threshold: Default confidence threshold for span-based tasks
        num_workers: DataLoader workers for parallel preprocessing
        format_results: Pass each raw result through ``self.format_results``
        include_confidence: Include confidence scores
        include_spans: Include character-level start/end positions

    Returns:
        One result dict per input text, in input order (``[]`` for no texts).

    Raises:
        ValueError: If a per-text schema sequence's length differs from
            ``len(texts)``.
    """
    if not texts:
        return []
    self.eval()
    self.processor.change_mode(is_training=False)

    # Normalize schemas: a list/tuple is per-text, anything else is shared.
    if isinstance(schemas, (list, tuple)):
        if len(schemas) != len(texts):
            raise ValueError(f"Schema count ({len(schemas)}) != text count ({len(texts)})")
        schema_list = list(schemas)
    else:
        schema_list = [schemas] * len(texts)

    # Build each distinct schema object once. A single shared schema would
    # otherwise be rebuilt and its metadata re-collected for every text.
    prepared = {}
    schema_dicts = []
    metadata_list = []
    for schema in schema_list:
        key = id(schema)  # objects stay alive in schema_list, so ids are unique
        if key not in prepared:
            prepared[key] = self._prepare_schema_for_batch(schema)
        schema_dict, metadata = prepared[key]
        schema_dicts.append(schema_dict)
        metadata_list.append(metadata)

    # OPT-9: Skip duplicate normalization — _collate_batch handles it
    dataset = list(zip(texts, schema_dicts))
    # OPT-11: Reuse cached collator instance
    if self._inference_collator is None:
        self._inference_collator = ExtractorCollator(self.processor, is_training=False)
    collator = self._inference_collator
    # OPT-12: Skip DataLoader overhead for single-batch inputs
    if len(dataset) <= batch_size and num_workers == 0:
        batches = [collator(dataset)]
    else:
        batches = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collator,
            pin_memory=torch.cuda.is_available(),
        )

    all_results = []
    sample_idx = 0
    device = next(self.parameters()).device
    for batch in batches:
        batch = batch.to(device)
        batch_len = len(batch)
        batch_meta = metadata_list[sample_idx:sample_idx + batch_len]
        batch_results = self._extract_from_batch(
            batch, threshold, batch_meta, include_confidence, include_spans
        )
        if format_results:
            batch_results = [
                self.format_results(
                    result,
                    include_confidence,
                    meta.get("relation_order", []),
                    meta.get("classification_tasks", []),
                )
                for result, meta in zip(batch_results, batch_meta)
            ]
        all_results.extend(batch_results)
        sample_idx += batch_len
    return all_results

def _prepare_schema_for_batch(self, schema: Union[Schema, Dict]) -> Tuple[Dict, Dict]:
    """Return ``(schema_dict, metadata)`` for one schema used by batch_extract.

    Objects with a ``build`` method (``Schema``) contribute their stored
    field/entity/relation metadata; plain dicts get empty metadata. In both
    cases every classification config gets ``true_label`` defaulted to
    ``["N/A"]`` — this mutates the schema dict in place.
    """
    if hasattr(schema, 'build'):
        schema_dict = schema.build()
        metadata = {
            "field_metadata": schema._field_metadata,
            "entity_metadata": schema._entity_metadata,
            "relation_metadata": getattr(schema, '_relation_metadata', {}),
            "field_orders": schema._field_orders,
            "entity_order": schema._entity_order,
            "relation_order": getattr(schema, '_relation_order', []),
        }
    else:
        schema_dict = schema
        metadata = {
            "field_metadata": {}, "entity_metadata": {},
            "relation_metadata": {}, "field_orders": {},
            "entity_order": [], "relation_order": [],
        }
    classifications = schema_dict.get("classifications", [])
    metadata["classification_tasks"] = [c["task"] for c in classifications]
    for cls_config in classifications:
        cls_config.setdefault("true_label", ["N/A"])
    return schema_dict, metadata
def _extract_from_batch(
    self,
    batch: PreprocessedBatch,
    threshold: float,
    metadata_list: List[Dict],
    include_confidence: bool,
    include_spans: bool
) -> List[Dict[str, Any]]:
    """Encode a preprocessed batch once, then decode each sample.

    Decoding is best-effort per sample: if a sample raises, the exception is
    logged with its traceback and ``{}`` is returned in its place, so the
    other samples of the batch still produce results.

    Args:
        batch: Collated batch (already on the model's device).
        threshold: Default span threshold.
        metadata_list: Per-sample schema metadata, aligned with the batch.
        include_confidence: Include confidence scores in results.
        include_spans: Include character offsets in results.

    Returns:
        One raw result dict per sample, in batch order.
    """
    import logging

    # One encoder forward pass for the whole batch.
    hidden_states = self.encoder(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask
    ).last_hidden_state
    all_token_embs, all_schema_embs = self.processor.extract_embeddings_from_batch(
        hidden_states,
        batch.input_ids,
        batch
    )
    results = []
    for i in range(len(batch)):
        try:
            sample_result = self._extract_sample(
                token_embs=all_token_embs[i],
                schema_embs=all_schema_embs[i],
                schema_tokens_list=batch.schema_tokens_list[i],
                task_types=batch.task_types[i],
                text_tokens=batch.text_tokens[i],
                original_text=batch.original_texts[i],
                schema=batch.original_schemas[i],
                start_mapping=batch.start_mappings[i],
                end_mapping=batch.end_mappings[i],
                threshold=threshold,
                metadata=metadata_list[i],
                include_confidence=include_confidence,
                include_spans=include_spans
            )
        except Exception:
            # Keep the rest of the batch, but make the failure visible and
            # debuggable (level ERROR, with traceback) rather than a stdout print.
            logging.getLogger(__name__).exception("Error extracting sample %d", i)
            sample_result = {}
        results.append(sample_result)
    return results
def _extract_sample(
    self,
    token_embs: torch.Tensor,
    schema_embs: List[List[torch.Tensor]],
    schema_tokens_list: List[List[str]],
    task_types: List[str],
    text_tokens: List[str],
    original_text: str,
    schema: Dict,
    start_mapping: List[int],
    end_mapping: List[int],
    threshold: float,
    metadata: Dict,
    include_confidence: bool,
    include_spans: bool
) -> Dict[str, Any]:
    """Decode every task prompt of a single sample.

    Prompts with fewer than four schema tokens or no schema embeddings are
    skipped. Classification prompts go to ``_extract_classification_result``;
    every other prompt goes to ``_extract_span_result``.

    Returns:
        Dict mapping each decoded task name to its raw result.
    """
    results = {}

    # Span representations are only needed when some prompt is not a
    # classification, and can only be built from non-empty token embeddings.
    wants_spans = any(kind != "classifications" for kind in task_types)
    span_info = (
        self.compute_span_rep(token_embs)
        if wants_spans and token_embs.numel() > 0
        else None
    )

    # "<structure>.<field>" -> choices, for structure fields declared with choices.
    cls_fields = {
        f"{parent}.{fname}": fval["choices"]
        for struct in schema.get("json_structures", [])
        for parent, fields in struct.items()
        for fname, fval in fields.items()
        if isinstance(fval, dict) and "choices" in fval
    }

    # OPT-3: Use start_mapping length instead of re-tokenizing text
    text_len = len(start_mapping)

    for prompt_idx, (prompt_tokens, kind) in enumerate(zip(schema_tokens_list, task_types)):
        if len(prompt_tokens) < 4 or not schema_embs[prompt_idx]:
            continue
        # Token 2 is the task name, optionally followed by " [DESCRIPTION] ...".
        task_name = prompt_tokens[2].partition(" [DESCRIPTION] ")[0]
        prompt_embs = torch.stack(schema_embs[prompt_idx])
        if kind == "classifications":
            self._extract_classification_result(
                results, task_name, schema, prompt_embs, prompt_tokens
            )
            continue
        self._extract_span_result(
            results, task_name, kind, prompt_embs, span_info,
            prompt_tokens, text_tokens, text_len, original_text,
            start_mapping, end_mapping, threshold, metadata,
            cls_fields, include_confidence, include_spans
        )
    return results
def _extract_classification_result(
    self,
    results: Dict,
    schema_name: str,
    schema: Dict,
    embs: torch.Tensor,
    schema_tokens: List[str]
):
    """Score one classification task and store its prediction in *results*.

    Args:
        results: Output dict; ``results[schema_name]`` is set.
        schema_name: Task name parsed from the prompt.
        schema: Built schema dict holding the ``"classifications"`` configs.
        embs: Stacked prompt embeddings; row 0 is skipped and the remaining
            rows are scored by ``self.classifier`` (assumed to be one row per
            label, in ``cls_config["labels"]`` order — confirm with processor).
        schema_tokens: Prompt tokens; ``schema_tokens[2]`` holds the task name.

    Stores:
        Multi-label: list of ``(label, prob)`` with ``prob >= cls_threshold``,
        or just the argmax label if none passes. Single-label: ``(label, prob)``
        for the argmax.

    Raises:
        ValueError: If no classification config matches the task.
    """
    configs = schema["classifications"]
    # Exact task-name match first: prefix matching alone picks the wrong
    # config when one task name is a prefix of another ("sentiment" vs
    # "sentiment_detailed"). The original prefix test remains as a fallback.
    cls_config = next((c for c in configs if c["task"] == schema_name), None)
    if cls_config is None:
        cls_config = next(
            (c for c in configs if schema_tokens[2].startswith(c["task"])), None
        )
    if cls_config is None:
        raise ValueError(f"No classification config found for task {schema_name!r}")

    logits = self.classifier(embs[1:]).squeeze(-1)
    activation = cls_config.get("class_act", "auto")
    is_multi = cls_config.get("multi_label", False)
    if activation == "sigmoid":
        probs = torch.sigmoid(logits)
    elif activation == "softmax":
        probs = torch.softmax(logits, dim=-1)
    else:
        # "auto": independent sigmoids for multi-label, softmax otherwise.
        probs = torch.sigmoid(logits) if is_multi else torch.softmax(logits, dim=-1)

    labels = cls_config["labels"]
    cls_threshold = cls_config.get("cls_threshold", 0.5)
    if is_multi:
        # One host transfer instead of one .item() per label.
        prob_values = probs.tolist()
        chosen = [
            (labels[j], prob_values[j])
            for j in range(len(labels))
            if prob_values[j] >= cls_threshold
        ]
        if not chosen:
            best = int(torch.argmax(probs).item())
            chosen = [(labels[best], prob_values[best])]
        results[schema_name] = chosen
    else:
        best = int(torch.argmax(probs).item())
        results[schema_name] = (labels[best], probs[best].item())
def _extract_span_result(
    self,
    results: Dict,
    schema_name: str,
    task_type: str,
    embs: torch.Tensor,
    span_info: Dict,
    schema_tokens: List[str],
    text_tokens: List[str],
    text_len: int,
    original_text: str,
    start_mapping: List[int],
    end_mapping: List[int],
    threshold: float,
    metadata: Dict,
    cls_fields: Dict,
    include_confidence: bool,
    include_spans: bool
):
    """Decode one span-based prompt (entities, a relation, or a structure).

    Sets ``results[schema_name]``. Without any declared field the value is
    ``[]`` for entities and ``{}`` otherwise. When the predicted instance
    count is <= 0 or no span representation exists it is ``[]`` for
    entities/relations and ``{}`` for structures. Otherwise decoding is
    delegated to ``_extract_entities`` / ``_extract_relations`` /
    ``_extract_structures``.
    """
    # Each field name is the token immediately after an [E]/[C]/[R] marker.
    field_names = [
        following
        for current, following in zip(schema_tokens, schema_tokens[1:])
        if current in ("[E]", "[C]", "[R]")
    ]
    is_entities = schema_name == "entities"
    if not field_names:
        results[schema_name] = [] if is_entities else {}
        return

    # Number of instances to decode, predicted from the first prompt embedding.
    count_logits = self.count_pred(embs[0].unsqueeze(0))
    pred_count = int(count_logits.argmax(dim=1).item())
    if pred_count <= 0 or span_info is None:
        results[schema_name] = [] if is_entities or task_type == "relations" else {}
        return

    # einsum output "bplk": instance, field, then the two span axes of span_rep.
    struct_proj = self.count_embed(embs[1:], pred_count)
    span_scores = torch.sigmoid(
        torch.einsum("lkd,bpd->bplk", span_info["span_rep"], struct_proj)
    )

    if is_entities:
        decoded = self._extract_entities(
            field_names, span_scores, text_len, text_tokens,
            original_text, start_mapping, end_mapping,
            threshold, metadata, include_confidence, include_spans
        )
    elif task_type == "relations":
        decoded = self._extract_relations(
            schema_name, field_names, span_scores, pred_count,
            text_len, text_tokens, original_text, start_mapping, end_mapping,
            threshold, metadata, include_confidence, include_spans
        )
    else:
        decoded = self._extract_structures(
            schema_name, field_names, span_scores, pred_count,
            text_len, text_tokens, original_text, start_mapping, end_mapping,
            threshold, metadata, cls_fields, include_confidence, include_spans
        )
    results[schema_name] = decoded
def _extract_entities(
    self,
    entity_names: List[str],
    span_scores: torch.Tensor,
    text_len: int,
    text_tokens: List[str],
    text: str,
    start_map: List[int],
    end_map: List[int],
    threshold: float,
    metadata: Dict,
    include_confidence: bool,
    include_spans: bool
) -> List[Dict]:
    """Collect spans for every requested entity type.

    Entity types are visited in ``metadata["entity_order"]`` when it is
    non-empty, otherwise in prompt order (``entity_names``). Per-entity
    ``threshold``/``dtype`` from ``metadata["entity_metadata"]`` override the
    defaults (*threshold*, ``"list"``).

    Returns:
        ``[OrderedDict(name -> value)]``, or ``[]`` if no type was visited.
        ``"list"`` entities map to ``_format_spans`` output; other dtypes map
        to one span (plain text or a dict, depending on the flags), or to
        ``""`` (no flags) / ``None`` (any flag) when nothing was found.
    """
    scores = span_scores[0, :, -text_len:]
    # name -> first index in entity_names (same result as list.index).
    position = {}
    for idx, name in enumerate(entity_names):
        position.setdefault(name, idx)
    entity_meta = metadata.get("entity_metadata", {})
    # Dict schemas passed to batch_extract carry an empty "entity_order";
    # fall back to prompt order instead of silently dropping every entity.
    order = metadata.get("entity_order") or entity_names

    entity_results = OrderedDict()
    for name in order:
        if name not in position:
            continue
        meta = entity_meta.get(name, {})
        meta_threshold = meta.get("threshold")
        ent_threshold = meta_threshold if meta_threshold is not None else threshold
        dtype = meta.get("dtype", "list")
        spans = self._find_spans(
            scores[position[name]], ent_threshold, text_len, text,
            start_map, end_map
        )
        if dtype == "list":
            entity_results[name] = self._format_spans(spans, include_confidence, include_spans)
        elif spans:
            # NOTE(review): spans[0] is the first span returned by _find_spans,
            # not necessarily the most confident one — confirm intended.
            text_val, conf, char_start, char_end = spans[0]
            if include_spans and include_confidence:
                entity_results[name] = {
                    "text": text_val,
                    "confidence": conf,
                    "start": char_start,
                    "end": char_end
                }
            elif include_spans:
                entity_results[name] = {
                    "text": text_val,
                    "start": char_start,
                    "end": char_end
                }
            elif include_confidence:
                entity_results[name] = {"text": text_val, "confidence": conf}
            else:
                entity_results[name] = text_val
        else:
            entity_results[name] = "" if not include_spans and not include_confidence else None
    return [entity_results] if entity_results else []
def _extract_relations(
    self,
    rel_name: str,
    field_names: List[str],
    span_scores: torch.Tensor,
    count: int,
    text_len: int,
    text_tokens: List[str],
    text: str,
    start_map: List[int],
    end_map: List[int],
    threshold: float,
    metadata: Dict,
    include_confidence: bool,
    include_spans: bool
) -> List[Union[Tuple[str, str], Dict]]:
    """Decode up to *count* (head, tail) pairs for relation *rel_name*.

    For each instance, the first span returned by ``_find_spans`` is taken
    for each field (in ``field_orders`` order when known). An instance is kept
    only when exactly two fields were decoded and both produced text.

    Returns:
        ``(head_text, tail_text)`` tuples when neither flag is set; otherwise
        ``{"head": {...}, "tail": {...}}`` dicts whose endpoints hold ``text``
        plus ``confidence`` and/or ``start``/``end`` according to the flags.
    """
    # A per-relation threshold overrides the call-level one when it is set.
    relation_meta = metadata.get("relation_metadata", {})
    override = relation_meta[rel_name].get("threshold") if rel_name in relation_meta else None
    rel_threshold = threshold if override is None else override

    ordered_fields = metadata.get("field_orders", {}).get(rel_name, field_names)

    # Keys reported for each endpoint, in output order.
    endpoint_keys = ["text"]
    if include_confidence:
        endpoint_keys.append("confidence")
    if include_spans:
        endpoint_keys += ["start", "end"]
    detailed = include_confidence or include_spans

    def as_endpoint(span):
        full = {"text": span[0], "confidence": span[1], "start": span[2], "end": span[3]}
        return {key: full[key] for key in endpoint_keys}

    instances = []
    for inst in range(count):
        inst_scores = span_scores[inst, :, -text_len:]
        picked = []
        for fname in ordered_fields:
            if fname not in field_names:
                continue
            spans = self._find_spans(
                inst_scores[field_names.index(fname)], rel_threshold, text_len, text,
                start_map, end_map
            )
            picked.append(spans[0] if spans else None)
        if len(picked) != 2:
            continue
        head, tail = picked
        if not (head and head[0] and tail and tail[0]):
            continue
        if detailed:
            instances.append({"head": as_endpoint(head), "tail": as_endpoint(tail)})
        else:
            # Tuple format kept for backward compatibility.
            instances.append((head[0], tail[0]))
    return instances
def _extract_structures(
    self,
    struct_name: str,
    field_names: List[str],
    span_scores: torch.Tensor,
    count: int,
    text_len: int,
    text_tokens: List[str],
    text: str,
    start_map: List[int],
    end_map: List[int],
    threshold: float,
    metadata: Dict,
    cls_fields: Dict,
    include_confidence: bool,
    include_spans: bool
) -> List[Dict]:
    """Decode up to *count* instances of structure *struct_name*.

    Fields listed in *cls_fields* are classifications over their choices,
    scored at each choice token's position in the schema prefix (every
    position before the last *text_len* ones). All other fields are text
    spans, optionally filtered by the field's regex validators.

    Returns:
        Instances (``OrderedDict`` field -> value) that contain at least one
        value that is neither ``None`` nor ``[]``.
    """
    ordered_fields = metadata.get("field_orders", {}).get(struct_name, field_names)
    field_meta_all = metadata.get("field_metadata", {})

    # Boundary between schema prefix and text along the position axis.
    # Explicit lengths are used because x[:-text_len] / x[-text_len:] would,
    # for text_len == 0, select no prefix / the whole sequence respectively.
    score_prefix_len = max(span_scores.shape[2] - text_len, 0)
    prefix_tokens = text_tokens[:max(len(text_tokens) - text_len, 0)]

    # Per-field settings (and choice-token positions) do not depend on the
    # instance, so resolve them once instead of once per instance.
    plans = []
    for fname in ordered_fields:
        if fname not in field_names:
            continue
        field_key = f"{struct_name}.{fname}"
        meta = field_meta_all.get(field_key, {})
        meta_threshold = meta.get("threshold")
        choice_positions = None
        if field_key in cls_fields:
            choice_positions = [
                (choice, self._find_choice_idx(choice, prefix_tokens))
                for choice in cls_fields[field_key]
            ]
        plans.append((
            fname,
            field_names.index(fname),
            meta_threshold if meta_threshold is not None else threshold,
            meta.get("dtype", "list"),
            meta.get("validators", []),
            choice_positions,
        ))

    instances = []
    for inst in range(count):
        text_scores = span_scores[inst, :, score_prefix_len:]
        instance = OrderedDict()
        for fname, fidx, field_threshold, dtype, validators, choice_positions in plans:
            if choice_positions is not None:
                # Classification field: scores come from the prefix positions.
                prefix_scores = span_scores[inst, fidx, :score_prefix_len]
                limit = prefix_scores.shape[0]
                if dtype == "list":
                    selected = []
                    seen = set()
                    for choice, pos in choice_positions:
                        if choice in seen or not 0 <= pos < limit:
                            continue
                        score = prefix_scores[pos, 0].item()
                        if score >= field_threshold:
                            selected.append(
                                {"text": choice, "confidence": score} if include_confidence else choice
                            )
                            seen.add(choice)
                    instance[fname] = selected
                else:
                    best, best_score = None, -1.0
                    for choice, pos in choice_positions:
                        if 0 <= pos < limit:
                            score = prefix_scores[pos, 0].item()
                            if score > best_score:
                                best, best_score = choice, score
                    if best is not None and best_score >= field_threshold:
                        instance[fname] = (
                            {"text": best, "confidence": best_score} if include_confidence else best
                        )
                    else:
                        instance[fname] = None
                continue

            # Regular span field - track positions
            spans = self._find_spans(
                text_scores[fidx], field_threshold, text_len, text,
                start_map, end_map
            )
            if validators:
                spans = [s for s in spans if all(v.validate(s[0]) for v in validators)]
            if dtype == "list":
                instance[fname] = self._format_spans(spans, include_confidence, include_spans)
            elif spans:
                text_val, conf, char_start, char_end = spans[0]
                if include_spans and include_confidence:
                    instance[fname] = {
                        "text": text_val,
                        "confidence": conf,
                        "start": char_start,
                        "end": char_end
                    }
                elif include_spans:
                    instance[fname] = {
                        "text": text_val,
                        "start": char_start,
                        "end": char_end
                    }
                elif include_confidence:
                    instance[fname] = {"text": text_val, "confidence": conf}
                else:
                    instance[fname] = text_val
            else:
                instance[fname] = None
        # Only add if has content
        if any(v is not None and v != [] for v in instance.values()):
            instances.append(instance)
    return instances
def _find_spans(
    self,
    scores: torch.Tensor,
    threshold: float,
    text_len: int,
    text: str,
    start_map: List[int],
    end_map: List[int]
) -> List[Tuple[str, float, int, int]]:
    """Find valid spans above threshold.

    Args:
        scores: 2-D tensor indexed as ``scores[start, width]``; a cell covers
            tokens ``start`` .. ``start + width`` (inclusive).
        threshold: Minimum score for a span to be kept.
        text_len: Number of text tokens; spans must lie within them.
        text: Original text.
        start_map / end_map: Token index -> character start / end offset.

    Returns:
        ``(text, confidence, char_start, char_end)`` tuples in the order
        ``torch.where`` yields them. Offsets are tightened past whitespace
        removed by ``strip()``, so ``text[char_start:char_end]`` equals the
        returned text. Empty/whitespace-only spans are dropped.
    """
    starts, widths = torch.where(scores >= threshold)
    if starts.numel() == 0:
        return []
    # One host transfer for all confidences instead of one .item() per span.
    confidences = scores[starts, widths].tolist()
    spans = []
    for start, width, conf in zip(starts.tolist(), widths.tolist(), confidences):
        end = start + width + 1
        if not (0 <= start < text_len and end <= text_len):
            continue
        try:
            char_start = start_map[start]
            char_end = end_map[end - 1]
        except (IndexError, KeyError):
            continue
        raw = text[char_start:char_end]
        text_span = raw.strip()
        if not text_span:
            continue
        # Skip the leading whitespace that strip() removed, and end right after
        # the stripped text, so the reported offsets match the reported text.
        char_start += len(raw) - len(raw.lstrip())
        char_end = char_start + len(text_span)
        spans.append((text_span, conf, char_start, char_end))
    return spans
def _format_spans(
    self,
    spans: List[Tuple],
    include_confidence: bool,
    include_spans: bool = False
) -> Union[List[str], List[Dict], List[Tuple]]:
    """Greedily keep non-overlapping spans, most confident first, then format.

    Spans are ``(text, confidence, start, end)``; a candidate is dropped when
    its ``[start, end)`` range intersects one already kept.

    Returns:
        Plain texts when neither flag is set; otherwise dicts with ``text``
        plus ``confidence`` and/or ``start``/``end`` according to the flags.
    """
    if not spans:
        return []
    kept = []
    for text_val, conf, c_start, c_end in sorted(spans, key=lambda span: span[1], reverse=True):
        disjoint = all(c_end <= k_start or c_start >= k_end for _, _, k_start, k_end in kept)
        if disjoint:
            kept.append((text_val, conf, c_start, c_end))

    if include_spans and include_confidence:
        return [{"text": t, "confidence": c, "start": s, "end": e} for t, c, s, e in kept]
    if include_spans:
        return [{"text": t, "start": s, "end": e} for t, _, s, e in kept]
    if include_confidence:
        return [{"text": t, "confidence": c} for t, c, _, _ in kept]
    return [t for t, _, _, _ in kept]
def _find_choice_idx(self, choice: str, tokens: List[str]) -> int:
    """Return the index of the token representing *choice*, or -1.

    Matching is case-insensitive and done in two passes:

    1. a token equal to *choice*, or whose part before a
       ``" [DESCRIPTION] "`` separator equals it;
    2. failing that, the first token that merely contains *choice*.

    The exact pass keeps an earlier token that only contains the choice as a
    substring (e.g. a field token ``"notes"`` for choice ``"no"``) from being
    picked over the choice's own token.
    """
    needle = choice.lower()
    lowered = [tok.lower() for tok in tokens]
    # NOTE(review): assumes choice tokens may carry a " [DESCRIPTION] ..."
    # suffix like the task-name token does; confirm against the processor.
    for idx, tok in enumerate(lowered):
        if tok == needle or tok.split(" [description] ", 1)[0] == needle:
            return idx
    for idx, tok in enumerate(lowered):
        if needle in tok:
            return idx
    return -1
# =========================================================================
# Result Formatting
# =========================================================================
def format_results(
    self,
    results: Dict,
    include_confidence: bool = False,
    requested_relations: Optional[List[str]] = None,
    classification_tasks: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Convert raw per-task extraction results into the public output shape.

    Each ``(key, value)`` in ``results`` is routed by this priority:

    1. Keys in ``classification_tasks`` are rendered as labels. A list of
       ``(label, confidence)`` pairs is a multi-label result; a single
       ``(label, confidence)`` tuple is a single-label result. Any other
       value passes through unchanged.
    2. A key is a relation when it is in ``requested_relations``, or when
       its value is a non-empty list whose first element is a 2-tuple or a
       dict with ``"head"`` and ``"tail"``. Relations are collected under
       ``formatted["relation_extraction"]``. A non-list relation value is
       replaced by ``[]``.
    3. All other keys are formatted by value shape:

       - an empty list stays as-is (``{}`` for the key ``"entities"``);
       - for ``"entities"``, the first dict of the list goes through
         ``_format_entity_dict``;
       - other lists of dicts, and plain dicts, go through
         ``_format_struct``;
       - ``(label, confidence)`` tuples are rendered as labels;
       - anything else is copied unchanged.

    Every name in ``requested_relations`` appears under
    ``relation_extraction`` (as ``[]`` when nothing was found). The
    ``relation_extraction`` key is only added when at least one relation
    name exists.

    Args:
        results: Raw results keyed by task name.
        include_confidence: Render labels/spans as dicts carrying their
            confidence instead of bare values.
        requested_relations: Relation names the caller asked for.
        classification_tasks: Task names to treat as classifications.

    Returns:
        The formatted result dict.
    """
    formatted: Dict[str, Any] = {}
    relations: Dict[str, Any] = {}
    requested_relations = requested_relations or []
    classification_tasks = classification_tasks or []

    def as_label(pair):
        """Render one ``(label, confidence)`` pair according to ``include_confidence``."""
        label, conf = pair
        return {"label": label, "confidence": conf} if include_confidence else label

    for key, value in results.items():
        # Classification membership takes priority over relation detection.
        is_classification = key in classification_tasks
        is_relation = False
        if not is_classification:
            if key in requested_relations:
                is_relation = True
            elif isinstance(value, list) and len(value) > 0:
                first = value[0]
                # NOTE: any list of 2-tuples matches here, including
                # (label, confidence) pairs for keys that were not listed in
                # classification_tasks. Those are routed to relations.
                if isinstance(first, tuple) and len(first) == 2:
                    is_relation = True
                elif isinstance(first, dict) and "head" in first and "tail" in first:
                    is_relation = True

        if is_classification:
            if isinstance(value, list):
                formatted[key] = [as_label(pair) for pair in value]
            elif isinstance(value, tuple):
                formatted[key] = as_label(value)
            else:
                formatted[key] = value
        elif is_relation:
            # Relations should always be lists; anything else becomes [].
            relations[key] = value if isinstance(value, list) else []
        elif isinstance(value, list):
            if len(value) == 0:
                formatted[key] = {} if key == "entities" else value
            elif isinstance(value[0], dict):
                if key == "entities":
                    # Only the first dict of an "entities" list is formatted.
                    formatted[key] = self._format_entity_dict(value[0], include_confidence)
                else:
                    formatted[key] = [self._format_struct(v, include_confidence) for v in value]
            elif isinstance(value[0], tuple):
                formatted[key] = [as_label(pair) for pair in value]
            else:
                formatted[key] = value
        elif isinstance(value, tuple):
            formatted[key] = as_label(value)
        elif isinstance(value, dict):
            formatted[key] = self._format_struct(value, include_confidence)
        else:
            formatted[key] = value

    # Report every requested relation, even those with no matches.
    for rel in requested_relations:
        relations.setdefault(rel, [])
    if relations:
        formatted["relation_extraction"] = relations
    return formatted
def _format_entity_dict(self, entities: Dict, include_confidence: bool) -> Dict:
formatted = {}
for name, spans in entities.items():
if isinstance(spans, list):
unique = []
seen = set()
for span in spans:
if isinstance(span, tuple):
text, conf, start, end = span
if text and text.lower() not in seen:
seen.add(text.lower())
unique.append({"text": text, "confidence": conf} if include_confidence else text)
elif isinstance(span, dict):
# Handle dict format (with confidence/spans)
text = span.get("text", "")
if text and text.lower() not in seen:
seen.add(text.lower())
unique.append(span)
else:
# Handle string format
if span and span.lower() not in seen:
seen.add(span.lower())
unique.append(span)
formatted[name] = unique
elif isinstance(spans, tuple):
text, conf, _, _ = spans
formatted[name] = {"text": text, "confidence": conf} if include_confidence and text else text
else:
formatted[name] = spans or None
return formatted
def _format_struct(self, struct: Dict, include_confidence: bool) -> Dict:
formatted = {}
for field, value in struct.items():
if isinstance(value, list):
unique = []
seen = set()
for v in value:
if isinstance(v, tuple):
text, conf, _, _ = v
if text and text.lower() not in seen:
seen.add(text.lower())
unique.append({"text": text, "confidence": conf} if include_confidence else text)
elif isinstance(v, dict):
# Handle dict format (with confidence/spans)
text = v.get("text", "")
if text and text.lower() not in seen:
seen.add(text.lower())
unique.append(v)
else:
# Handle string format
if v and v.lower() not in seen:
seen.add(v.lower())
unique.append(v)
formatted[field] = unique
elif isinstance(value, tuple):
text, conf, _, _ = value
formatted[field] = {"text": text, "confidence": conf} if include_confidence and text else text
elif value:
formatted[field] = value
else:
formatted[field] = None
return formatted
# =========================================================================
# Convenience Methods (route through batch)
# =========================================================================
def extract(self, text: str, schema, threshold: float = 0.5,
            format_results: bool = True, include_confidence: bool = False,
            include_spans: bool = False) -> Dict:
    """Run ``schema`` over a single text and return its result.

    This is a thin wrapper around ``batch_extract``. It sends a one-element
    batch with batch size 1 and ``0`` in the slot after ``threshold``
    (presumably ``num_workers``, going by the module docstring example),
    then unwraps the single result.
    """
    single_batch = self.batch_extract(
        [text],
        schema,
        1,
        threshold,
        0,
        format_results,
        include_confidence,
        include_spans,
    )
    return single_batch[0]
def extract_entities(self, text: str, entity_types, threshold: float = 0.5,
                     format_results: bool = True, include_confidence: bool = False,
                     include_spans: bool = False) -> Dict:
    """Extract entities of the given types from a single text.

    Builds an entity-only schema with ``create_schema().entities(...)`` and
    delegates to ``extract``.
    """
    entity_schema = self.create_schema().entities(entity_types)
    return self.extract(
        text,
        entity_schema,
        threshold=threshold,
        format_results=format_results,
        include_confidence=include_confidence,
        include_spans=include_spans,
    )
def batch_extract_entities(self, texts: List[str], entity_types, batch_size: int = 8,
                           threshold: float = 0.5, format_results: bool = True,
                           include_confidence: bool = False, include_spans: bool = False,
                           num_workers: int = 0) -> List[Dict]:
    """Extract entities of the given types from many texts.

    Builds an entity-only schema with ``create_schema().entities(...)`` and
    delegates to ``batch_extract``.

    Args:
        texts: Texts to process.
        entity_types: Entity types to look for.
        batch_size: Number of texts per batch.
        threshold: Minimum span score to keep.
        format_results: Whether to format the raw results.
        include_confidence: Include confidence scores in formatted output.
        include_spans: Include character offsets in formatted output.
        num_workers: Forwarded to ``batch_extract`` in the slot that was
            previously hard-coded to ``0``. The module docstring example
            passes ``num_workers=4`` for parallel preprocessing. The default
            keeps the old behaviour.

    Returns:
        One result dict per input text.
    """
    schema = self.create_schema().entities(entity_types)
    return self.batch_extract(texts, schema, batch_size, threshold, num_workers,
                              format_results, include_confidence, include_spans)
def classify_text(self, text: str, tasks: Dict, threshold: float = 0.5,
                  format_results: bool = True, include_confidence: bool = False,
                  include_spans: bool = False) -> Dict:
    """Classify a single text against the tasks in ``tasks``.

    See ``_build_classification_schema`` for the accepted task configs.
    """
    schema = self._build_classification_schema(tasks)
    return self.extract(
        text,
        schema,
        threshold=threshold,
        format_results=format_results,
        include_confidence=include_confidence,
        include_spans=include_spans,
    )

def batch_classify_text(self, texts: List[str], tasks: Dict, batch_size: int = 8,
                        threshold: float = 0.5, format_results: bool = True,
                        include_confidence: bool = False, include_spans: bool = False,
                        num_workers: int = 0) -> List[Dict]:
    """Classify many texts against the tasks in ``tasks``.

    Args:
        texts: Texts to classify.
        tasks: Task configs (see ``_build_classification_schema``).
        batch_size: Number of texts per batch.
        threshold: Minimum score to keep a label.
        format_results: Whether to format the raw results.
        include_confidence: Include confidence scores in formatted output.
        include_spans: Forwarded to ``batch_extract``.
        num_workers: Forwarded to ``batch_extract`` in the slot that was
            previously hard-coded to ``0`` (the default keeps that behaviour).

    Returns:
        One result dict per input text.
    """
    schema = self._build_classification_schema(tasks)
    return self.batch_extract(texts, schema, batch_size, threshold, num_workers,
                              format_results, include_confidence, include_spans)

def _build_classification_schema(self, tasks: Dict):
    """Build a schema holding one classification task per entry of ``tasks``.

    How each config is handled:

    - A dict config with a ``"labels"`` key is split: the labels are passed
      positionally and the remaining keys as keyword options to
      ``schema.classification``. The caller's dict is not mutated.
    - Any other config is passed as the labels argument unchanged.
    """
    schema = self.create_schema()
    for name, config in tasks.items():
        if isinstance(config, dict) and "labels" in config:
            options = config.copy()
            labels = options.pop("labels")
            schema.classification(name, labels, **options)
        else:
            schema.classification(name, config)
    return schema
def extract_json(self, text: str, structures: Dict, threshold: float = 0.5,
                 format_results: bool = True, include_confidence: bool = False,
                 include_spans: bool = False) -> Dict:
    """Extract structured data from a single text.

    See ``_build_structure_schema`` for the accepted ``structures`` format.
    """
    schema = self._build_structure_schema(structures)
    return self.extract(
        text,
        schema,
        threshold=threshold,
        format_results=format_results,
        include_confidence=include_confidence,
        include_spans=include_spans,
    )

def batch_extract_json(self, texts: List[str], structures: Dict, batch_size: int = 8,
                       threshold: float = 0.5, format_results: bool = True,
                       include_confidence: bool = False, include_spans: bool = False,
                       num_workers: int = 0) -> List[Dict]:
    """Extract structured data from many texts.

    Args:
        texts: Texts to process.
        structures: Structure definitions (see ``_build_structure_schema``).
        batch_size: Number of texts per batch.
        threshold: Minimum span score to keep.
        format_results: Whether to format the raw results.
        include_confidence: Include confidence scores in formatted output.
        include_spans: Include character offsets in formatted output.
        num_workers: Forwarded to ``batch_extract`` in the slot that was
            previously hard-coded to ``0`` (the default keeps that behaviour).

    Returns:
        One result dict per input text.
    """
    schema = self._build_structure_schema(structures)
    return self.batch_extract(texts, schema, batch_size, threshold, num_workers,
                              format_results, include_confidence, include_spans)

def _build_structure_schema(self, structures: Dict):
    """Build a schema with one structure per key of ``structures``.

    Each value is an iterable of field specs, parsed by
    ``_parse_field_spec``. A spec is either a ``"name::dtype::choices::desc"``
    string or a dict.
    """
    schema = self.create_schema()
    for parent, fields in structures.items():
        builder = schema.structure(parent)
        for spec in fields:
            name, dtype, choices, desc = self._parse_field_spec(spec)
            builder.field(name, dtype=dtype, choices=choices, description=desc)
    return schema
def extract_relations(self, text: str, relation_types, threshold: float = 0.5,
                      format_results: bool = True, include_confidence: bool = False,
                      include_spans: bool = False) -> Dict:
    """Extract relations of the given types from a single text.

    Builds a relations-only schema with ``create_schema().relations(...)``
    and delegates to ``extract``.
    """
    relation_schema = self.create_schema().relations(relation_types)
    return self.extract(
        text,
        relation_schema,
        threshold=threshold,
        format_results=format_results,
        include_confidence=include_confidence,
        include_spans=include_spans,
    )
def batch_extract_relations(self, texts: List[str], relation_types, batch_size: int = 8,
                            threshold: float = 0.5, format_results: bool = True,
                            include_confidence: bool = False, include_spans: bool = False,
                            num_workers: int = 0) -> List[Dict]:
    """Extract relations of the given types from many texts.

    Builds a relations-only schema with ``create_schema().relations(...)``
    and delegates to ``batch_extract``.

    Args:
        texts: Texts to process.
        relation_types: Relation types to look for.
        batch_size: Number of texts per batch.
        threshold: Minimum score to keep.
        format_results: Whether to format the raw results.
        include_confidence: Forwarded to ``batch_extract``.
        include_spans: Forwarded to ``batch_extract``.
        num_workers: Forwarded to ``batch_extract`` in the slot that was
            previously hard-coded to ``0`` (the default keeps that behaviour).

    Returns:
        One result dict per input text.
    """
    schema = self.create_schema().relations(relation_types)
    return self.batch_extract(texts, schema, batch_size, threshold, num_workers,
                              format_results, include_confidence, include_spans)
def _parse_field_spec(self, spec: Union[str, Dict]) -> Tuple[str, str, Optional[List[str]], Optional[str]]:
"""Parse field specification string or dictionary.
Format: "name::dtype::choices::description" where all parts after name are optional.
- dtype: 'str' for single value, 'list' for multiple values
- choices: [option1|option2|...] for enumerated options
- description: free text description
Examples:
"restaurant::str::Restaurant name"
"seating::[indoor|outdoor|bar]::Seating preference"
"dietary::[vegetarian|vegan|gluten-free|none]::list::Dietary restrictions"
"""
if isinstance(spec, dict):
return (
spec.get("name", ""),
spec.get("dtype", "list"),
spec.get("choices"),
spec.get("description")
)
parts = spec.split('::')
name = parts[0]
dtype, choices, desc = "list", None, None
dtype_explicitly_set = False
if len(parts) == 1:
return name, dtype, choices, desc
for part in parts[1:]:
if part in ['str', 'list']:
dtype = part
dtype_explicitly_set = True
elif part.startswith('[') and part.endswith(']'):
choices = [c.strip() for c in part[1:-1].split('|')]
# Only default to "str" if dtype wasn't explicitly set
if not dtype_explicitly_set:
dtype = "str"
else:
desc = part
return name, dtype, choices, desc
# Aliases: alternate public names bound to GLiNER2, Schema and StructureBuilder
# (presumably kept so older import paths keep working — confirm before removal).
BuilderExtractor, SchemaBuilder, JsonStructBuilder = GLiNER2, Schema, StructureBuilder