This commit is contained in:
lovebird 2026-03-06 12:59:32 +01:00
parent afa907065b
commit 9b32c8dd29
30 changed files with 16618 additions and 0 deletions

packages/GLiNER2/.gitignore (vendored, Normal file, 97 lines)

@@ -0,0 +1,97 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Ruff
.ruff_cache/
# IDEs
.idea/
.vscode/
*.swp
*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Model files (typically large)
*.pt
*.pth
*.bin
*.onnx
*.safetensors
# Logs
*.log
test_api_client.py

packages/GLiNER2/LICENSE (Normal file, 201 lines)

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

packages/GLiNER2/README.md (Normal file, 1031 lines)

File diff suppressed because it is too large

@@ -0,0 +1,79 @@
# PyPI Release Guide for GLiNER2
## Prerequisites
- [ ] Python 3.8+ installed
- [ ] PyPI account with API token configured
- [ ] Write access to the repository
## Release Steps
### 1. Update Version
Update version in `gliner2/__init__.py`:
```python
__version__ = "1.0.1" # New version
```
### 2. Build Package
```bash
# Install build tools
pip install build twine
# Clean previous builds
rm -rf dist/ build/ *.egg-info/
# Build package
python -m build
```
### 3. Test Build (Optional)
```bash
# Test on TestPyPI first
twine upload --repository testpypi dist/*
# Install and test (--extra-index-url lets dependencies resolve from production PyPI,
# since they are usually not mirrored on TestPyPI)
pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ gliner2
```
### 4. Upload to PyPI
```bash
# Upload to production PyPI
twine upload dist/*
```
### 5. Create GitHub Release
1. Go to GitHub repository → Releases
2. Click "Create a new release"
3. Tag: `v1.0.1` (matching version)
4. Title: `GLiNER2 v1.0.1`
5. Description: Summary of changes
6. Attach built wheels from `dist/` folder
### 6. Verify Release
```bash
# Install from PyPI
pip install gliner2==1.0.1
# Test basic functionality
python -c "from gliner2 import GLiNER2; print('✓ Import successful')"
```
## Troubleshooting
- **Authentication error**: Configure PyPI token in `~/.pypirc` or use `--username __token__`
- **File exists error**: Version already exists on PyPI, increment version number
- **Build fails**: Check `pyproject.toml` dependencies and Python version compatibility
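
For the authentication error above, a minimal `~/.pypirc` might look like the sketch below (the token value is a placeholder; substitute your own PyPI API token):

```ini
[pypi]
username = __token__
password = pypi-XXXXXXXXXXXXXXXX
```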
## Checklist
- [ ] Version updated in `__init__.py`
- [ ] Package builds without errors
- [ ] Uploaded to PyPI successfully
- [ ] GitHub release created
- [ ] Installation verified

@@ -0,0 +1,401 @@
"""
Statistical benchmark with confidence intervals and p-values.
Micro-benchmarks: old/new variants are interleaved in the same process and compared with a paired t-test.
End-to-end: raw timings are saved to JSON so separate runs can be compared with Welch's t-test.
Usage:
# Baseline
git stash
python benchmark_statistical.py --tag baseline --n 300
git stash pop
# Optimized
python benchmark_statistical.py --tag optimized --n 300
# Compare
python benchmark_statistical.py --compare baseline optimized
"""
import argparse
import json
import math
import random
import time
import statistics
import sys
from collections import OrderedDict
import torch
from scipy import stats as sp_stats
# ─── Helpers ──────────────────────────────────────────────────────
def sync():
if torch.cuda.is_available():
torch.cuda.synchronize()
def ci95(data):
"""95% CI half-width using t-distribution."""
n = len(data)
if n < 2:
return 0.0
se = statistics.stdev(data) / math.sqrt(n)
t_crit = sp_stats.t.ppf(0.975, df=n - 1)
return t_crit * se
def collect(fn, n_warmup, n_iter):
"""Run fn with warmup, return list of times in ms."""
for _ in range(n_warmup):
fn()
sync()
times = []
for _ in range(n_iter):
sync()
t0 = time.perf_counter()
fn()
sync()
times.append((time.perf_counter() - t0) * 1000)
return times
def paired_test(old_times, new_times):
"""Paired t-test on matched samples. Returns (t_stat, p_value, mean_diff, ci95_diff)."""
diffs = [o - n for o, n in zip(old_times, new_times)]
n = len(diffs)
mean_d = statistics.mean(diffs)
se_d = statistics.stdev(diffs) / math.sqrt(n)
t_stat = mean_d / se_d if se_d > 0 else 0
p_val = 2 * sp_stats.t.sf(abs(t_stat), df=n - 1)
hw = ci95(diffs)
return t_stat, p_val, mean_d, hw
def welch_test(a, b):
"""Welch's t-test (unequal variance). Returns (t_stat, p_value)."""
t_stat, p_val = sp_stats.ttest_ind(a, b, equal_var=False)
return t_stat, p_val
def fmt_p(p):
if p < 0.001:
return f"{p:.2e}"
return f"{p:.4f}"
# ─── End-to-end benchmark ────────────────────────────────────────
def run_e2e(n_iter, n_warmup):
"""Run end-to-end scenarios, return dict of {name: [times]}."""
from gliner2 import GLiNER2
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
model = model.to(device)
model.eval()
text1 = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12, 2023."
ents = ["company", "person", "product", "location", "date"]
texts8 = [
"Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino.",
"Google's Sundar Pichai spoke at the conference in Mountain View.",
"Microsoft released Windows 11 in Redmond last year.",
"Amazon founder Jeff Bezos invested in Blue Origin in Seattle.",
"Tesla CEO Elon Musk unveiled the Cybertruck at the Fremont factory.",
"Meta's Mark Zuckerberg presented Quest 3 in Menlo Park.",
"NVIDIA's Jensen Huang showcased the H100 GPU at GTC in San Jose.",
"OpenAI CEO Sam Altman launched GPT-4 in San Francisco.",
]
long_text = (
"Apple Inc., headquartered in Cupertino, California, is a multinational technology company "
"founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in April 1976. The company designs, "
"develops, and sells consumer electronics, computer software, and online services. Tim Cook "
"has served as CEO since August 2011. Apple's main products include the iPhone, iPad, Mac, "
"Apple Watch, and AirPods. The company also operates services including the App Store, "
"Apple Music, iCloud, and Apple TV Plus. In 2023, Apple reported annual revenue of $383 "
"billion, making it the world's largest technology company by revenue. The company employs "
"over 160,000 people worldwide."
)
ents6 = ["company", "person", "product", "location", "date", "monetary_value"]
text_struct = "John Smith, aged 35, is a software engineer at Google in Mountain View."
schema_struct = model.create_schema()
schema_struct.structure("person").field("name").field("age").field("job_title").field("company").field("location")
text_rel = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12."
rels = ["CEO_of", "located_in", "announced_on"]
results = OrderedDict()
scenarios = [
("single_entity", lambda: model.extract_entities(text1, ents)),
("single_structure", lambda: model.extract(text_struct, schema_struct)),
("single_relation", lambda: model.extract_relations(text_rel, rels)),
("batch8_entity", lambda: model.batch_extract_entities(texts8, ents, batch_size=8)),
("long_text_entity", lambda: model.extract_entities(long_text, ents6)),
]
for name, fn in scenarios:
print(f" Running {name} (n={n_iter})...", end=" ", flush=True)
times = collect(fn, n_warmup, n_iter)
results[name] = times
m, hw = statistics.mean(times), ci95(times)
print(f"{m:.2f} ± {hw:.2f} ms")
return results
# ─── Micro-benchmarks (interleaved old/new) ──────────────────────
def run_micro(n_iter, n_warmup):
"""Run micro-benchmarks with interleaved old/new for paired comparison."""
import copy
from gliner2 import GLiNER2
from gliner2.training.trainer import ExtractorCollator
from torch.utils.data import DataLoader
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
model = model.to(device)
model.eval()
tokenizer = model.processor.tokenizer
results = OrderedDict()
# --- OPT-1: Token ID lookup ---
special_set_str = {"[P]", "[C]", "[E]", "[R]", "[L]"}
special_ids = frozenset(tokenizer.convert_tokens_to_ids(t) for t in special_set_str)
dummy_ids = list(range(200))
def opt1_old():
for tid in dummy_ids:
tok = tokenizer.convert_ids_to_tokens(tid)
_ = tok in special_set_str
def opt1_new():
for tid in dummy_ids:
_ = tid in special_ids
print(" OPT-1 Token ID lookup...", end=" ", flush=True)
old_t, new_t = _interleaved(opt1_old, opt1_new, n_warmup, n_iter)
results["OPT-1 Token ID lookup"] = {"old": old_t, "new": new_t}
_print_paired(old_t, new_t)
# --- OPT-3: Avoid retokenization ---
test_text = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12."
dummy_map = list(range(15))
def opt3_old():
return len(model.processor._tokenize_text(test_text))
def opt3_new():
return len(dummy_map)
print(" OPT-3 Avoid retokenization...", end=" ", flush=True)
old_t, new_t = _interleaved(opt3_old, opt3_new, n_warmup, n_iter)
results["OPT-3 Avoid retokenization"] = {"old": old_t, "new": new_t}
_print_paired(old_t, new_t)
# --- OPT-4: Deepcopy ---
schema_dict = {
"json_structures": [{"person": {"name": "", "age": "", "job": ""}}],
"entities": {"company": "", "location": ""},
"relations": [], "classifications": [],
}
record = {"text": "Apple CEO Tim Cook announced iPhone 15." * 3, "schema": schema_dict}
def opt4_old():
return copy.deepcopy(record)
def opt4_new():
return {"text": record["text"], "schema": copy.deepcopy(record["schema"])}
print(" OPT-4 Deepcopy...", end=" ", flush=True)
old_t, new_t = _interleaved(opt4_old, opt4_new, n_warmup, n_iter)
results["OPT-4 Deepcopy"] = {"old": old_t, "new": new_t}
_print_paired(old_t, new_t)
# --- OPT-6: Token cache ---
special_tokens = ["[SEP_STRUCT]", "[SEP_TEXT]", "[P]", "[C]", "[E]", "[R]", "[L]",
"[EXAMPLE]", "[OUTPUT]", "[DESCRIPTION]", "(", ")", ",", "|"]
cache = {tok: tokenizer.tokenize(tok) for tok in special_tokens}
test_tokens = special_tokens * 10
def opt6_old():
for tok in test_tokens:
tokenizer.tokenize(tok)
def opt6_new():
for tok in test_tokens:
if tok in cache:
_ = cache[tok]
else:
tokenizer.tokenize(tok)
print(" OPT-6 Token cache...", end=" ", flush=True)
old_t, new_t = _interleaved(opt6_old, opt6_new, n_warmup, n_iter)
results["OPT-6 Token cache"] = {"old": old_t, "new": new_t}
_print_paired(old_t, new_t)
# --- OPT-12: Skip DataLoader ---
collator = ExtractorCollator(model.processor, is_training=False)
text_norm = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12, 2023."
schema_e = model.create_schema().entities(["company", "person", "product", "location", "date"])
sd = schema_e.build()
for c in sd.get("classifications", []):
c.setdefault("true_label", ["N/A"])
small_dataset = [(text_norm, sd)]
def opt12_old():
loader = DataLoader(small_dataset, batch_size=8, shuffle=False,
num_workers=0, collate_fn=collator)
return list(loader)
def opt12_new():
return [collator(small_dataset)]
print(" OPT-12 Skip DataLoader...", end=" ", flush=True)
old_t, new_t = _interleaved(opt12_old, opt12_new, n_warmup, n_iter)
results["OPT-12 Skip DataLoader"] = {"old": old_t, "new": new_t}
_print_paired(old_t, new_t)
return results
def _interleaved(old_fn, new_fn, n_warmup, n_iter):
"""Run old/new interleaved to eliminate ordering effects. Returns paired lists."""
# Warmup both
for _ in range(n_warmup):
old_fn()
new_fn()
sync()
old_times = []
new_times = []
for _ in range(n_iter):
# Randomize order each iteration to eliminate systematic bias
if random.random() < 0.5:
sync(); t0 = time.perf_counter(); old_fn(); sync()
old_times.append((time.perf_counter() - t0) * 1000)
sync(); t0 = time.perf_counter(); new_fn(); sync()
new_times.append((time.perf_counter() - t0) * 1000)
else:
sync(); t0 = time.perf_counter(); new_fn(); sync()
new_times.append((time.perf_counter() - t0) * 1000)
sync(); t0 = time.perf_counter(); old_fn(); sync()
old_times.append((time.perf_counter() - t0) * 1000)
return old_times, new_times
def _print_paired(old_t, new_t):
m_old, m_new = statistics.mean(old_t), statistics.mean(new_t)
t_stat, p_val, mean_diff, hw = paired_test(old_t, new_t)
speedup = m_old / m_new if m_new > 0 else float('inf')
print(f"{m_old:.4f} -> {m_new:.4f} ms ({speedup:.1f}x) "
f"diff={mean_diff:.4f}±{hw:.4f}ms p={fmt_p(p_val)}")
# ─── Compare mode ────────────────────────────────────────────────
def compare(baseline_path, optimized_path):
"""Compare two end-to-end result files with Welch's t-test."""
with open(baseline_path) as f:
baseline = json.load(f)
with open(optimized_path) as f:
optimized = json.load(f)
print(f"\nBaseline: {baseline_path} (device={baseline['device']}, n={baseline.get('n', '?')})")
print(f"Optimized: {optimized_path} (device={optimized['device']}, n={optimized.get('n', '?')})")
print(f"\n{'Scenario':<25} {'Baseline':>18} {'Optimized':>18} {'Diff':>14} {'Speedup':>8} {'p-value':>10}")
print("=" * 100)
for name in baseline["e2e"]:
b = baseline["e2e"][name]
o = optimized["e2e"][name]
m_b, ci_b = statistics.mean(b), ci95(b)
m_o, ci_o = statistics.mean(o), ci95(o)
diff = m_b - m_o
diff_ci = math.sqrt(ci_b**2 + ci_o**2) # approximate CI of difference
speedup = m_b / m_o if m_o > 0 else float('inf')
t_stat, p_val = welch_test(b, o)
sig = "*" if p_val < 0.05 else " "
if p_val < 0.01:
sig = "**"
if p_val < 0.001:
sig = "***"
print(f"{name:<25} {m_b:>7.2f}±{ci_b:>5.2f}ms {m_o:>7.2f}±{ci_o:>5.2f}ms "
f"{diff:>+6.2f}±{diff_ci:>4.2f}ms {speedup:>7.3f}x {fmt_p(p_val):>9}{sig}")
# Micro-benchmarks (if present in optimized)
if "micro" in optimized:
print(f"\n{'Component':<30} {'Old':>16} {'New':>16} {'Diff (paired)':>18} {'Speedup':>8} {'p-value':>10}")
print("=" * 105)
for name, data in optimized["micro"].items():
old_t = data["old"]
new_t = data["new"]
m_old, ci_old = statistics.mean(old_t), ci95(old_t)
m_new, ci_new = statistics.mean(new_t), ci95(new_t)
t_stat, p_val, mean_diff, hw = paired_test(old_t, new_t)
speedup = m_old / m_new if m_new > 0 else float('inf')
sig = "*" if p_val < 0.05 else " "
if p_val < 0.01: sig = "**"
if p_val < 0.001: sig = "***"
print(f"{name:<30} {m_old:>6.4f}±{ci_old:>6.4f}ms {m_new:>6.4f}±{ci_new:>6.4f}ms "
f"{mean_diff:>+7.4f}±{hw:>6.4f}ms {speedup:>7.1f}x {fmt_p(p_val):>9}{sig}")
# ─── Main ────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--tag", help="Tag for this run (baseline or optimized)")
parser.add_argument("--n", type=int, default=300, help="Iterations per scenario")
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
parser.add_argument("--compare", nargs=2, metavar=("BASELINE", "OPTIMIZED"),
help="Compare two result files")
args = parser.parse_args()
if args.compare:
compare(
f"bench_stats_{args.compare[0]}.json",
f"bench_stats_{args.compare[1]}.json"
)
return
if not args.tag:
parser.error("--tag is required (or use --compare)")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
print(f"Iterations: {args.n}, Warmup: {args.warmup}\n")
output = {"tag": args.tag, "device": device, "n": args.n}
# End-to-end
print("END-TO-END BENCHMARKS")
print("-" * 60)
e2e = run_e2e(args.n, args.warmup)
output["e2e"] = e2e
# Micro-benchmarks (only meaningful for optimized run since we inline both versions)
print("\nCOMPONENT MICRO-BENCHMARKS (interleaved old/new)")
print("-" * 60)
micro = run_micro(args.n, args.warmup)
output["micro"] = micro
out_path = f"bench_stats_{args.tag}.json"
with open(out_path, "w") as f:
json.dump(output, f)
print(f"\nRaw timings saved to {out_path}")
if __name__ == "__main__":
main()
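
The paired-vs-unpaired distinction the script relies on can be sketched with synthetic timings: when the old and new measurements share per-iteration noise (which interleaving ensures), differencing cancels that noise and a small constant effect becomes detectable. A stdlib-only sketch mirroring `paired_test` above, with purely illustrative numbers:

```python
import math
import random
import statistics

def paired_t(old, new):
    # Same statistic as paired_test() above: t computed on per-iteration differences.
    diffs = [o - n for o, n in zip(old, new)]
    mean_d = statistics.mean(diffs)
    se_d = statistics.stdev(diffs) / math.sqrt(len(diffs))
    return mean_d / se_d if se_d > 0 else 0.0

random.seed(0)
# Shared per-iteration jitter (e.g. machine noise) dominates the raw variance;
# the true improvement is a constant 0.05 ms.
jitter = [random.gauss(0.0, 1.0) for _ in range(300)]
old = [5.00 + z for z in jitter]
new = [4.95 + z + random.gauss(0.0, 0.01) for z in jitter]

t_paired = paired_t(old, new)  # large: differencing cancels the shared jitter
# An unpaired comparison sees sd ~1.0 ms against a 0.05 ms effect, so it would
# need far more samples to reach the same significance level.
```

This is why the micro-benchmarks interleave both variants in one process, while cross-process end-to-end runs (which cannot be paired) fall back to Welch's t-test.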

@@ -0,0 +1,23 @@
__version__ = "1.2.4"
from .inference.engine import GLiNER2, RegexValidator
from .model import Extractor, ExtractorConfig
from .api_client import (
GLiNER2API,
GLiNER2APIError,
AuthenticationError,
ValidationError,
ServerError,
)
from .training.lora import (
LoRAConfig,
LoRAAdapterConfig,
LoRALayer,
load_lora_adapter,
save_lora_adapter,
unload_lora_adapter,
has_lora_adapter,
apply_lora_to_model,
merge_lora_weights,
unmerge_lora_weights,
)

@@ -0,0 +1,989 @@
"""
GLiNER2 API Client
This module provides an API-based wrapper for GLiNER2 that mirrors the local
model interface. It allows seamless switching between local and API-based
inference.
Usage:
>>> from gliner2 import GLiNER2
>>>
>>> # Load from API (uses environment variable for API key)
>>> extractor = GLiNER2.from_api()
>>>
>>> # Use exactly like local model
>>> results = extractor.extract_entities(
... "Apple released iPhone 15 in September 2023.",
... ["company", "product", "date"]
... )
"""
from __future__ import annotations
import os
import logging
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Union, Literal
from urllib.parse import urljoin
from urllib3.util import Retry
import requests
from requests.adapters import HTTPAdapter
logger = logging.getLogger(__name__)
class GLiNER2APIError(Exception):
"""Base exception for GLiNER2 API errors."""
def __init__(self, message: str, status_code: Optional[int] = None, response_data: Optional[Dict] = None):
super().__init__(message)
self.status_code = status_code
self.response_data = response_data
class AuthenticationError(GLiNER2APIError):
"""Raised when API key is invalid or expired."""
pass
class ValidationError(GLiNER2APIError):
"""Raised when request data is invalid."""
pass
class ServerError(GLiNER2APIError):
"""Raised when server encounters an error."""
pass
class StructureBuilderAPI:
"""
Builder for structured data schemas for API-based extraction.
This mirrors the interface of StructureBuilder from the local model.
"""
def __init__(self, schema: 'SchemaAPI', parent: str):
self.schema = schema
self.parent = parent
self.fields = OrderedDict()
self.field_order = []
self._finished = False
def field(
self,
name: str,
dtype: Literal["str", "list"] = "list",
choices: Optional[List[str]] = None,
description: Optional[str] = None,
threshold: Optional[float] = None,
validators: Optional[List] = None
) -> 'StructureBuilderAPI':
"""Add a field to the structured data."""
# Warn if validators are used (not supported in API mode)
if validators:
warnings.warn(
f"Field '{name}': RegexValidator is not supported in API mode. "
"Validators will be ignored. Use local model for regex-based filtering.",
UserWarning,
stacklevel=2
)
self.fields[name] = {
"dtype": dtype,
"choices": choices,
"description": description,
"threshold": threshold
}
self.field_order.append(name)
return self
def _auto_finish(self):
"""Automatically finish this structure when needed."""
if not self._finished:
# Convert fields to API format
# Use dict format if any field has threshold or choices (advanced features)
# Otherwise use simple string format for backwards compatibility
field_specs = []
for name in self.field_order:
config = self.fields[name]
# Check if advanced features are used
has_threshold = config.get('threshold') is not None
has_choices = config.get('choices') is not None
if has_threshold or has_choices:
# Use dict format for advanced features
field_dict = {"name": name, "dtype": config['dtype']}
if config.get('description'):
field_dict["description"] = config['description']
if has_threshold:
field_dict["threshold"] = config['threshold']
if has_choices:
field_dict["choices"] = config['choices']
field_specs.append(field_dict)
else:
# Use simple string format: "name::type::description"
spec = f"{name}::{config['dtype']}"
if config.get('description'):
spec += f"::{config['description']}"
field_specs.append(spec)
self.schema._structures[self.parent] = field_specs
self._finished = True
def __getattr__(self, name):
"""Auto-finish when any schema method is called."""
if hasattr(self.schema, name):
self._auto_finish()
return getattr(self.schema, name)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
class SchemaAPI:
"""Schema builder for API-based extraction tasks."""
def __init__(self):
self._entities = None
self._entity_dtype = "list"
self._entity_threshold = None
self._classifications = {}
self._structures = {}
self._relations = None
self._relation_threshold = None
self._active_structure_builder = None
def entities(
self,
entity_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
dtype: Literal["str", "list"] = "list",
threshold: Optional[float] = None
) -> 'SchemaAPI':
"""Add entity extraction task."""
if self._active_structure_builder:
self._active_structure_builder._auto_finish()
self._active_structure_builder = None
# Normalize to list or dict
if isinstance(entity_types, str):
self._entities = [entity_types]
elif isinstance(entity_types, list):
self._entities = entity_types
elif isinstance(entity_types, dict):
self._entities = entity_types
self._entity_dtype = dtype
self._entity_threshold = threshold
return self
def classification(
self,
task: str,
labels: Union[List[str], Dict[str, str]],
multi_label: bool = False,
cls_threshold: float = 0.5,
**kwargs
) -> 'SchemaAPI':
"""Add a text classification task."""
if self._active_structure_builder:
self._active_structure_builder._auto_finish()
self._active_structure_builder = None
# Parse labels
if isinstance(labels, dict):
label_names = list(labels.keys())
else:
label_names = labels
self._classifications[task] = {
"labels": label_names,
"multi_label": multi_label,
"cls_threshold": cls_threshold
}
return self
def structure(self, name: str) -> StructureBuilderAPI:
"""Start building a structured data schema."""
if self._active_structure_builder:
self._active_structure_builder._auto_finish()
self._active_structure_builder = StructureBuilderAPI(self, name)
return self._active_structure_builder
def relations(
self,
relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
threshold: Optional[float] = None
) -> 'SchemaAPI':
"""
Add relation extraction task.
Args:
relation_types: Relation types to extract. Can be:
- str: Single relation type
- List[str]: Multiple relation types
- Dict[str, str]: Relation types with descriptions
- Dict[str, Dict]: Relation types with full configuration
threshold: Default confidence threshold for relations.
Returns:
Self for method chaining.
"""
if self._active_structure_builder:
self._active_structure_builder._auto_finish()
self._active_structure_builder = None
        # Normalize: a bare string becomes a one-element list; lists and
        # dicts (types with descriptions/config) are stored as-is.
        if isinstance(relation_types, str):
            self._relations = [relation_types]
        elif isinstance(relation_types, (list, dict)):
            self._relations = relation_types
        else:
            raise TypeError(
                f"relation_types must be str, list, or dict, got {type(relation_types).__name__}"
            )
self._relation_threshold = threshold
return self
def build(self) -> Dict[str, Any]:
"""Build the schema for API request."""
if self._active_structure_builder:
self._active_structure_builder._auto_finish()
self._active_structure_builder = None
schema = {}
if self._entities is not None:
schema["entities"] = self._entities
schema["entity_dtype"] = self._entity_dtype
if self._entity_threshold is not None:
schema["entity_threshold"] = self._entity_threshold
if self._classifications:
schema["classifications"] = self._classifications
if self._structures:
schema["structures"] = self._structures
if self._relations is not None:
schema["relations"] = self._relations
if self._relation_threshold is not None:
schema["relation_threshold"] = self._relation_threshold
return schema
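Built schemas are plain dicts, so the fluent chain is easy to reason about in isolation. A standalone sketch (illustrative helper names, not part of the package) of the normalization and sparse assembly that `entities()` and `build()` perform:

```python
from typing import Dict, List, Union

def normalize_types(types: Union[str, List[str], Dict[str, str]]):
    """Mirror SchemaAPI's normalization: a bare string becomes a
    one-element list; lists and dicts pass through unchanged."""
    if isinstance(types, str):
        return [types]
    return types

def build_schema(entities=None, classifications=None) -> Dict:
    """Assemble only the sections that were actually configured,
    matching SchemaAPI.build()'s sparse output."""
    schema = {}
    if entities is not None:
        schema["entities"] = normalize_types(entities)
    if classifications:
        schema["classifications"] = classifications
    return schema

print(build_schema(entities="person"))  # {'entities': ['person']}
```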
class GLiNER2API:
"""
API-based GLiNER2 client that mirrors the local model interface.
This class provides the same methods as GLiNER2 but makes HTTP requests
to the API endpoint instead of running local inference.
Attributes:
api_key: API authentication key
base_url: API base URL
timeout: Request timeout in seconds
max_retries: Maximum number of retries for failed requests
"""
DEFAULT_BASE_URL = "https://api.fastino.ai"
def __init__(
self,
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
timeout: float = 30.0,
max_retries: int = 3,
):
"""
Initialize the GLiNER2 API client.
Args:
api_key: API authentication key. If not provided, reads from
PIONEER_API_KEY environment variable.
api_base_url: Override the default API base URL.
timeout: Request timeout in seconds.
max_retries: Maximum number of retries for failed requests.
Raises:
ValueError: If no API key is provided and PIONEER_API_KEY is not set.
"""
# Read API key from environment if not provided
if api_key is None:
api_key = os.environ.get("PIONEER_API_KEY")
if api_key is None:
raise ValueError(
"API key must be provided either as an argument or via "
"PIONEER_API_KEY environment variable"
)
self.api_key = api_key
self.base_url = api_base_url or os.environ.get(
"GLINER2_API_BASE_URL", self.DEFAULT_BASE_URL
)
self.timeout = timeout
self.max_retries = max_retries
# Setup HTTP session with retry logic
self.session = requests.Session()
self.session.headers.update({
"X-API-Key": api_key,
"Content-Type": "application/json",
})
# Configure retry strategy
retry_strategy = Retry(
total=max_retries,
            backoff_factor=1,  # exponential backoff between retries
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["POST"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.session.mount("https://", adapter)
self.session.mount("http://", adapter)
logger.debug(f"Initialized GLiNER2API for {self.base_url}")
def _make_request(
self,
task: str,
text: Union[str, List[str]],
schema: Union[List[str], Dict],
threshold: float = 0.5,
include_confidence: bool = False,
include_spans: bool = False,
format_results: bool = True,
) -> Dict[str, Any]:
"""
Make an HTTP request to the GLiNER-2 API.
Args:
task: Task type (extract_entities, classify_text, extract_json, schema)
text: Text to process (string or list for batch)
schema: Schema for extraction
threshold: Confidence threshold
include_confidence: Whether to include confidence scores in results
include_spans: Whether to include character-level start/end positions
format_results: Whether to format results (False for raw extraction data)
Returns:
API response result
Raises:
GLiNER2APIError: If request fails
"""
# Ensure base_url ends with / for proper joining
base = self.base_url.rstrip('/') + '/'
url = urljoin(base, "gliner-2")
payload = {
"task": task,
"text": text,
"schema": schema,
"threshold": threshold,
"include_confidence": include_confidence,
"include_spans": include_spans,
"format_results": format_results,
}
logger.debug(f"Making POST request to {url}")
try:
response = self.session.post(
url,
json=payload,
timeout=self.timeout,
)
logger.debug(f"Response status: {response.status_code}")
# Handle different error codes
if response.status_code == 401:
error_data = response.json() if response.content else None
error_msg = (
error_data.get("detail", "Invalid or expired API key")
if error_data else "Invalid or expired API key"
)
raise AuthenticationError(error_msg, response_data=error_data)
elif response.status_code in (400, 422):
error_data = response.json() if response.content else None
error_msg = (
error_data.get("detail", "Request validation failed")
if error_data else "Request validation failed"
)
raise ValidationError(
error_msg,
status_code=response.status_code,
response_data=error_data,
)
elif response.status_code >= 500:
error_data = response.json() if response.content else None
error_msg = (
error_data.get("detail", "Server error occurred")
if error_data else "Server error occurred"
)
raise ServerError(
error_msg,
status_code=response.status_code,
response_data=error_data,
)
elif not response.ok:
error_data = response.json() if response.content else None
error_msg = (
error_data.get("detail", f"Request failed with status {response.status_code}")
if error_data else f"Request failed with status {response.status_code}"
)
raise GLiNER2APIError(
error_msg,
status_code=response.status_code,
response_data=error_data,
)
data = response.json()
return data.get("result", data)
except requests.exceptions.Timeout:
raise GLiNER2APIError(f"Request timed out after {self.timeout}s")
except requests.exceptions.ConnectionError as e:
raise GLiNER2APIError(f"Connection error: {str(e)}")
except requests.exceptions.RequestException as e:
raise GLiNER2APIError(f"Request failed: {str(e)}")
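The branching above establishes a simple precedence for mapping status codes onto the exception hierarchy. A dependency-free sketch of just that mapping (the exception classes are stubbed here; in this module they are assumed to be defined earlier, alongside `GLiNER2APIError`):

```python
class GLiNER2APIError(Exception): pass
class AuthenticationError(GLiNER2APIError): pass
class ValidationError(GLiNER2APIError): pass
class ServerError(GLiNER2APIError): pass

def error_class_for(status: int):
    """Mirror _make_request's branching: 401 -> auth, 400/422 ->
    validation, >=500 -> server, any other non-2xx -> generic."""
    if status == 401:
        return AuthenticationError
    if status in (400, 422):
        return ValidationError
    if status >= 500:
        return ServerError
    if not (200 <= status < 300):
        return GLiNER2APIError
    return None  # success: no exception raised

print(error_class_for(503).__name__)  # ServerError
```

Note that 429 falls through to the generic error: the session-level `Retry` handles rate limits first, and only an exhausted retry budget surfaces here.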
def create_schema(self) -> SchemaAPI:
"""Create a new schema for defining extraction tasks."""
return SchemaAPI()
# -------------------------------------------------------------------------
# Entity Extraction Methods
# -------------------------------------------------------------------------
def extract_entities(
self,
text: str,
entity_types: Union[List[str], Dict[str, Union[str, Dict]]],
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> Dict[str, Any]:
"""
Extract entities from text.
Args:
text: Input text to extract entities from.
entity_types: List of entity types or dict with descriptions.
threshold: Minimum confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores in results.
include_spans: Whether to include character-level start/end positions.
Returns:
Dictionary with "entities" key containing extracted entities.
If include_confidence=True, entity values include confidence scores.
If include_spans=True, entity values include start/end positions.
If format_results=False, returns raw extraction data with positions.
"""
# Normalize entity types to list
if isinstance(entity_types, dict):
entities = list(entity_types.keys())
else:
entities = entity_types
result = self._make_request(
task="extract_entities",
text=text,
schema=entities,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
# Wrap result in expected format if needed (only for formatted results)
if format_results and isinstance(result, dict) and "entities" not in result:
return {"entities": result}
return result
def batch_extract_entities(
self,
texts: List[str],
entity_types: Union[List[str], Dict[str, Union[str, Dict]]],
batch_size: int = 8,
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> List[Dict[str, Any]]:
"""
Batch extract entities from multiple texts.
Args:
texts: List of input texts.
entity_types: List of entity types or dict with descriptions.
batch_size: Batch size (used by API for optimization).
threshold: Minimum confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores.
include_spans: Whether to include character-level start/end positions.
Returns:
List of dictionaries with "entities" key.
If include_confidence=True, entity values include confidence scores.
If include_spans=True, entity values include start/end positions.
If format_results=False, returns raw extraction data with positions.
"""
# Normalize entity types to list
if isinstance(entity_types, dict):
entities = list(entity_types.keys())
else:
entities = entity_types
result = self._make_request(
task="extract_entities",
text=texts,
schema=entities,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
# Ensure result is a list
if isinstance(result, dict):
return [result]
return result
# -------------------------------------------------------------------------
# Text Classification Methods
# -------------------------------------------------------------------------
def classify_text(
self,
text: str,
tasks: Dict[str, Union[List[str], Dict[str, Any]]],
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> Dict[str, Any]:
"""
Classify text into categories.
Args:
text: Text to classify.
tasks: Classification tasks where keys are task names.
threshold: Confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores.
include_spans: Whether to include character-level start/end positions.
Returns:
Classification results keyed by task name.
If include_confidence=True, results include confidence scores.
If format_results=False, returns raw extraction data.
"""
# Convert tasks to API format
# For classify_text task, schema should be {"categories": [...]}
# But for multi-task, we need to use the schema task
if len(tasks) == 1:
# Single task - use classify_text endpoint
task_name = list(tasks.keys())[0]
task_config = tasks[task_name]
if isinstance(task_config, dict) and "labels" in task_config:
categories = task_config["labels"]
else:
categories = task_config
result = self._make_request(
task="classify_text",
text=text,
schema={"categories": categories},
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
# Wrap result with task name (only for formatted results)
if format_results and isinstance(result, dict) and task_name not in result:
return {task_name: result.get("classification", result)}
return result
else:
# Multiple tasks - use schema endpoint
schema = {"classifications": tasks}
result = self._make_request(
task="schema",
text=text,
schema=schema,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
return result
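The single-task versus multi-task dispatch in `classify_text` can be summarized without any HTTP machinery; this sketch (hypothetical helper, not part of the class) mirrors the routing logic:

```python
def route_classification(tasks: dict):
    """Mirror classify_text's dispatch: a single task uses the
    dedicated classify_text endpoint with a {"categories": ...}
    schema; multiple tasks fall back to the generic schema task."""
    if len(tasks) == 1:
        name, config = next(iter(tasks.items()))
        if isinstance(config, dict) and "labels" in config:
            labels = config["labels"]
        else:
            labels = config
        return "classify_text", {"categories": labels}
    return "schema", {"classifications": tasks}

print(route_classification({"sentiment": ["pos", "neg"]}))
# ('classify_text', {'categories': ['pos', 'neg']})
```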
def batch_classify_text(
self,
texts: List[str],
tasks: Dict[str, Union[List[str], Dict[str, Any]]],
batch_size: int = 8,
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> List[Dict[str, Any]]:
"""
Batch classify multiple texts.
Args:
texts: List of texts to classify.
tasks: Classification tasks.
batch_size: Batch size.
threshold: Confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores.
include_spans: Whether to include character-level start/end positions.
Returns:
List of classification results.
If include_confidence=True, results include confidence scores.
If format_results=False, returns raw extraction data.
"""
# Use schema task for batch classification
schema = {"classifications": tasks}
result = self._make_request(
task="schema",
text=texts,
schema=schema,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
if isinstance(result, dict):
return [result]
return result
# -------------------------------------------------------------------------
# JSON Extraction Methods
# -------------------------------------------------------------------------
def extract_json(
self,
text: str,
structures: Dict[str, List[str]],
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> Dict[str, Any]:
"""
Extract structured data from text.
Args:
text: Text to extract data from.
structures: Structure definitions with field specs.
threshold: Minimum confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores.
include_spans: Whether to include character-level start/end positions.
Returns:
Extracted structures keyed by structure name.
If include_confidence=True, field values include confidence scores.
If include_spans=True, field values include start/end positions.
If format_results=False, returns raw extraction data with positions.
"""
result = self._make_request(
task="extract_json",
text=text,
schema=structures,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
return result
def batch_extract_json(
self,
texts: List[str],
structures: Dict[str, List[str]],
batch_size: int = 8,
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> List[Dict[str, Any]]:
"""
Batch extract structured data from multiple texts.
Args:
texts: List of texts.
structures: Structure definitions.
batch_size: Batch size.
threshold: Confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores.
include_spans: Whether to include character-level start/end positions.
Returns:
List of extracted structures.
If include_confidence=True, field values include confidence scores.
If include_spans=True, field values include start/end positions.
If format_results=False, returns raw extraction data with positions.
"""
result = self._make_request(
task="extract_json",
text=texts,
schema=structures,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
if isinstance(result, dict):
return [result]
return result
# -------------------------------------------------------------------------
# Relation Extraction Methods
# -------------------------------------------------------------------------
def extract_relations(
self,
text: str,
relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> Dict[str, Any]:
"""
Extract relations between entities from text.
Args:
text: Input text to extract relations from.
relation_types: Relation types to extract. Can be:
- str: Single relation type
- List[str]: Multiple relation types
- Dict[str, str]: Relation types with descriptions
- Dict[str, Dict]: Relation types with full configuration
threshold: Minimum confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores in results.
include_spans: Whether to include character-level start/end positions.
Returns:
Dictionary with "relation_extraction" key containing extracted relations.
Relations are grouped by type with tuples (source, target).
Format: {"relation_extraction": {"relation_name": [("source", "target"), ...]}}
"""
# Build schema with relations
schema = self.create_schema().relations(relation_types).build()
result = self._make_request(
task="schema",
text=text,
schema=schema,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
return result
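The documented output shape groups relations by type into `(source, target)` tuples. A small sketch of that grouping from a flat triple list (the flat-triple input is an illustrative assumption, not the API's wire format):

```python
from collections import defaultdict

def group_relations(triples):
    """Group (relation, source, target) triples into the documented
    output shape:
    {"relation_extraction": {rel: [(source, target), ...]}}."""
    grouped = defaultdict(list)
    for rel, src, tgt in triples:
        grouped[rel].append((src, tgt))
    return {"relation_extraction": dict(grouped)}

print(group_relations([("works_at", "Ada", "ACME")]))
# {'relation_extraction': {'works_at': [('Ada', 'ACME')]}}
```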
def batch_extract_relations(
self,
texts: List[str],
relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
batch_size: int = 8,
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> List[Dict[str, Any]]:
"""
Batch extract relations from multiple texts.
Args:
texts: List of input texts.
relation_types: Relation types to extract.
batch_size: Batch size (used by API for optimization).
threshold: Minimum confidence threshold.
format_results: Whether to format results.
include_confidence: Whether to include confidence scores.
include_spans: Whether to include character-level start/end positions.
Returns:
List of dictionaries with "relation_extraction" key.
Format: [{"relation_extraction": {"relation_name": [("source", "target"), ...]}}]
"""
# Build schema with relations
schema = self.create_schema().relations(relation_types).build()
result = self._make_request(
task="schema",
text=texts,
schema=schema,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
# Ensure result is a list
if isinstance(result, dict):
return [result]
return result
# -------------------------------------------------------------------------
# General Extraction Methods
# -------------------------------------------------------------------------
def extract(
self,
text: str,
schema: Union[SchemaAPI, Dict[str, Any]],
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> Dict[str, Any]:
"""
Extract information from text using a schema.
Args:
text: Input text to extract from.
schema: Schema defining what to extract.
threshold: Minimum confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores.
include_spans: Whether to include character-level start/end positions.
Returns:
Extraction results organized by task name.
If include_confidence=True, values include confidence scores.
If include_spans=True, values include start/end positions.
If format_results=False, returns raw extraction data with positions.
"""
        # Build schema dict if needed: anything schema-like with a build()
        # method (including SchemaAPI) is built; plain dicts pass through.
        if hasattr(schema, 'build'):
            schema_dict = schema.build()
        else:
            schema_dict = schema
# Validate schema has at least one extraction task
has_any_task = any(
key in schema_dict
for key in ["entities", "classifications", "structures", "relations"]
)
if not has_any_task:
raise ValueError("Schema must contain at least one extraction task")
# Always use schema task to preserve all metadata (thresholds, dtypes, etc.)
return self._make_request(
task="schema",
text=text,
schema=schema_dict,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
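The guard before the request reduces to a single membership check over the four task sections; as a standalone sketch:

```python
def has_extraction_task(schema: dict) -> bool:
    """Mirror extract()'s guard: a schema must define at least one of
    the four task sections before it is sent to the API."""
    return any(
        key in schema
        for key in ("entities", "classifications", "structures", "relations")
    )

print(has_extraction_task({"entities": ["person"]}))  # True
```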
def batch_extract(
self,
texts: List[str],
schemas: Union[SchemaAPI, List[SchemaAPI], Dict[str, Any], List[Dict[str, Any]]],
batch_size: int = 8,
threshold: float = 0.5,
format_results: bool = True,
include_confidence: bool = False,
include_spans: bool = False
) -> List[Dict[str, Any]]:
"""
Extract information from multiple texts.
Args:
texts: List of input texts.
schemas: Single schema for all texts or list of schemas.
batch_size: Batch size.
threshold: Confidence threshold.
format_results: Whether to format results. If False, returns raw extraction data.
include_confidence: Whether to include confidence scores.
include_spans: Whether to include character-level start/end positions.
Returns:
List of extraction results.
If include_confidence=True, values include confidence scores.
If include_spans=True, values include start/end positions.
If format_results=False, returns raw extraction data with positions.
"""
if not texts:
return []
# Handle schema variations
if isinstance(schemas, list):
if len(schemas) != len(texts):
raise ValueError(
f"Number of schemas ({len(schemas)}) must match number of texts ({len(texts)})"
)
# Warn user about multi-schema batch limitation
warnings.warn(
"Multi-schema batch (different schemas per text) is not natively supported by the API. "
"Each text will be processed individually, which may be slower than single-schema batch. "
"For better performance, use the same schema for all texts.",
UserWarning,
stacklevel=2
)
# Process each text with its schema individually
results = []
for text, schema in zip(texts, schemas):
                results.append(
                    self.extract(
                        text, schema, threshold,
                        include_confidence=include_confidence,
                        include_spans=include_spans,
                        format_results=format_results,
                    )
                )
return results
        # Single schema for all texts: anything schema-like with a build()
        # method (including SchemaAPI) is built; plain dicts pass through.
        if hasattr(schemas, 'build'):
            schema_dict = schemas.build()
        else:
            schema_dict = schemas
return self._make_request(
task="schema",
text=texts,
schema=schema_dict,
threshold=threshold,
include_confidence=include_confidence,
include_spans=include_spans,
format_results=format_results,
)
# -------------------------------------------------------------------------
# Utility Methods
# -------------------------------------------------------------------------
def close(self):
"""Close the HTTP session."""
self.session.close()
def __enter__(self):
"""Context manager entry."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.close()


@@ -0,0 +1 @@
from .engine import RegexValidator, GLiNER2

File diff suppressed because it is too large


@@ -0,0 +1,191 @@
"""
Pydantic models for validating schema input from JSON/dict.
This module provides validation models for creating GLiNER2 schemas
from JSON or dictionary inputs.
"""
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator
class FieldInput(BaseModel):
"""Validates a single structure field.
Args:
name: Field name
dtype: Data type - 'str' for single value, 'list' for multiple values
choices: Optional list of valid choices for classification-style fields
description: Optional description of the field
"""
name: str = Field(..., min_length=1, description="Field name")
dtype: Literal["str", "list"] = Field(default="list", description="Data type")
choices: Optional[List[str]] = Field(default=None, description="Valid choices")
description: Optional[str] = Field(default=None, description="Field description")
@field_validator('choices')
@classmethod
def validate_choices(cls, v: Optional[List[str]]) -> Optional[List[str]]:
"""Ensure choices list is not empty if provided."""
if v is not None and len(v) == 0:
raise ValueError("choices must contain at least one option")
return v
class StructureInput(BaseModel):
"""Validates a structure block.
Args:
fields: List of field definitions
"""
fields: List[FieldInput] = Field(..., min_length=1, description="List of fields")
class ClassificationInput(BaseModel):
"""Validates a classification task.
Args:
task: Task name
labels: List of classification labels
multi_label: Whether multiple labels can be selected
"""
task: str = Field(..., min_length=1, description="Task name")
labels: List[str] = Field(..., min_length=2, description="Classification labels")
multi_label: bool = Field(default=False, description="Multi-label classification")
@field_validator('labels')
@classmethod
def validate_labels(cls, v: List[str]) -> List[str]:
"""Ensure labels are unique and non-empty."""
if len(v) != len(set(v)):
raise ValueError("labels must be unique")
if any(not label.strip() for label in v):
raise ValueError("labels cannot be empty strings")
return v
class SchemaInput(BaseModel):
"""Root schema validation model.
Args:
entities: List of entity types or dict mapping types to descriptions
structures: Dict mapping structure names to structure definitions
classifications: List of classification task definitions
relations: List of relation types or dict mapping types to config
"""
entities: Optional[Union[List[str], Dict[str, str]]] = Field(
default=None,
description="Entity types"
)
structures: Optional[Dict[str, StructureInput]] = Field(
default=None,
description="Structure definitions"
)
classifications: Optional[List[ClassificationInput]] = Field(
default=None,
description="Classification tasks"
)
relations: Optional[Union[List[str], Dict[str, Dict[str, Any]]]] = Field(
default=None,
description="Relation types"
)
@field_validator('entities')
@classmethod
def validate_entities(
cls,
v: Optional[Union[List[str], Dict[str, str]]]
) -> Optional[Union[List[str], Dict[str, str]]]:
"""Validate entities format."""
if v is None:
return v
if isinstance(v, list):
if len(v) == 0:
raise ValueError("entities list cannot be empty")
if any(not entity.strip() for entity in v):
raise ValueError("entity names cannot be empty strings")
if len(v) != len(set(v)):
raise ValueError("entity names must be unique")
elif isinstance(v, dict):
if len(v) == 0:
raise ValueError("entities dict cannot be empty")
if any(not key.strip() for key in v.keys()):
raise ValueError("entity names cannot be empty strings")
return v
@field_validator('structures')
@classmethod
def validate_structures(
cls,
v: Optional[Dict[str, StructureInput]]
) -> Optional[Dict[str, StructureInput]]:
"""Validate structures format."""
if v is None:
return v
if len(v) == 0:
raise ValueError("structures dict cannot be empty")
if any(not key.strip() for key in v.keys()):
raise ValueError("structure names cannot be empty strings")
return v
@field_validator('classifications')
@classmethod
def validate_classifications(
cls,
v: Optional[List[ClassificationInput]]
) -> Optional[List[ClassificationInput]]:
"""Validate classifications format."""
if v is None:
return v
if len(v) == 0:
raise ValueError("classifications list cannot be empty")
# Check for duplicate task names
task_names = [cls_task.task for cls_task in v]
if len(task_names) != len(set(task_names)):
raise ValueError("classification task names must be unique")
return v
@field_validator('relations')
@classmethod
def validate_relations(
cls,
v: Optional[Union[List[str], Dict[str, Dict[str, Any]]]]
) -> Optional[Union[List[str], Dict[str, Dict[str, Any]]]]:
"""Validate relations format."""
if v is None:
return v
if isinstance(v, list):
if len(v) == 0:
raise ValueError("relations list cannot be empty")
if any(not rel.strip() for rel in v):
raise ValueError("relation names cannot be empty strings")
if len(v) != len(set(v)):
raise ValueError("relation names must be unique")
elif isinstance(v, dict):
if len(v) == 0:
raise ValueError("relations dict cannot be empty")
if any(not key.strip() for key in v.keys()):
raise ValueError("relation names cannot be empty strings")
return v
@model_validator(mode='after')
def validate_at_least_one_section(self) -> 'SchemaInput':
"""Ensure at least one section is provided."""
if all(
getattr(self, field) is None
for field in ['entities', 'structures', 'classifications', 'relations']
):
raise ValueError(
"At least one of entities, structures, classifications, "
"or relations must be provided"
)
return self
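The field validators above repeat one invariant: sections must be non-empty, names non-blank, and names unique. A dependency-free sketch of that shared check (hypothetical helper, not part of this module):

```python
from typing import Iterable

def check_names(names: Iterable[str], section: str) -> None:
    """Raise ValueError under the same conditions as the field
    validators above: empty section, blank names, or duplicates."""
    names = list(names)
    if not names:
        raise ValueError(f"{section} cannot be empty")
    if any(not n.strip() for n in names):
        raise ValueError(f"{section} names cannot be empty strings")
    if len(names) != len(set(names)):
        raise ValueError(f"{section} names must be unique")

check_names(["person", "org"], "entities")  # passes silently
```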


@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def create_mlp(input_dim, intermediate_dims, output_dim, dropout=0.1, activation="gelu", add_layer_norm=False):
"""
Creates a multi-layer perceptron (MLP) with specified dimensions and activation functions.
"""
activation_mapping = {
"relu": nn.ReLU,
"tanh": nn.Tanh,
"sigmoid": nn.Sigmoid,
"leaky_relu": nn.LeakyReLU,
"gelu": nn.GELU
}
layers = []
in_dim = input_dim
for dim in intermediate_dims:
layers.append(nn.Linear(in_dim, dim))
if add_layer_norm:
layers.append(nn.LayerNorm(dim))
layers.append(activation_mapping[activation]())
if dropout > 0:
layers.append(nn.Dropout(dropout))
in_dim = dim
layers.append(nn.Linear(in_dim, output_dim))
return nn.Sequential(*layers)
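As a quick sanity check on `create_mlp`'s layout: each intermediate dimension contributes a Linear, an optional LayerNorm, an activation, and an optional Dropout, followed by one final Linear. A torch-free sketch of the resulting module count:

```python
def mlp_layer_count(n_intermediate: int,
                    dropout: float = 0.1,
                    add_layer_norm: bool = False) -> int:
    """Count the modules create_mlp would stack for a given config."""
    # Per intermediate block: Linear + activation, plus optional
    # LayerNorm and optional Dropout (only when dropout > 0).
    per_block = 2 + (1 if add_layer_norm else 0) + (1 if dropout > 0 else 0)
    return n_intermediate * per_block + 1  # plus the final Linear

print(mlp_layer_count(2))  # 2 * (Linear + GELU + Dropout) + final Linear = 7
```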
class DownscaledTransformer(nn.Module):
def __init__(self, input_size, hidden_size, num_heads=4, num_layers=2, dropout=0.1):
"""
Initializes a downscaled transformer with specified parameters.
"""
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_layers = num_layers
self.in_projector = nn.Linear(input_size, hidden_size)
encoder = nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=hidden_size * 2,
dropout=dropout,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder, num_layers=num_layers)
self.out_projector = create_mlp(
input_dim=hidden_size + input_size,
intermediate_dims=[input_size, input_size],
output_dim=input_size,
dropout=0.,
activation="relu",
add_layer_norm=False
)
def forward(self, x):
"""
Args:
x (Tensor): Input tensor of shape (L, M, input_size).
Returns:
Tensor: Output tensor of shape (L, M, input_size).
"""
original_x = x
# Project input to hidden size.
x = self.in_projector(x)
        # Apply transformer encoder.
x = self.transformer(x)
# Concatenate original input with transformer output.
x = torch.cat([x, original_x], dim=-1)
# Project back to input size.
x = self.out_projector(x)
return x
class CountLSTM(nn.Module):
def __init__(self, hidden_size, max_count=20):
"""
Initializes the module with a learned positional embedding for count steps and a GRU.
"""
super().__init__()
self.hidden_size = hidden_size
self.max_count = max_count
# Learned positional embeddings for count steps: shape (max_count, hidden_size)
self.pos_embedding = nn.Embedding(max_count, hidden_size)
# Use a GRU layer; input shape is (seq_len, batch, input_size)
self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)
# Projector layer: combines GRU output with original embeddings.
self.projector = create_mlp(
input_dim=hidden_size * 2,
intermediate_dims=[hidden_size * 4],
output_dim=hidden_size,
dropout=0.,
activation="relu",
add_layer_norm=False
)
def forward(self, pc_emb: torch.Tensor, gold_count_val: int) -> torch.Tensor:
"""
Args:
pc_emb (Tensor): Field embeddings of shape (M, hidden_size).
gold_count_val (int): Predicted count value (number of steps).
Returns:
Tensor: Count-aware structure embeddings of shape (gold_count_val, M, hidden_size).
"""
M, D = pc_emb.shape
# Cap gold_count_val by max_count.
gold_count_val = min(gold_count_val, self.max_count)
# Create a sequence of count indices: shape (gold_count_val,)
count_indices = torch.arange(gold_count_val, device=pc_emb.device)
# Get positional embeddings for each count: (gold_count_val, hidden_size)
pos_seq = self.pos_embedding(count_indices)
# Expand pos_seq over the batch dimension: (gold_count_val, M, hidden_size)
pos_seq = pos_seq.unsqueeze(1).expand(gold_count_val, M, D)
# Initialize the GRU hidden state with the field embeddings.
h0 = pc_emb.unsqueeze(0) # shape: (1, M, hidden_size)
# Run the GRU over the count sequence.
output, _ = self.gru(pos_seq, h0)
# Concatenate the GRU outputs with the original field embeddings.
return self.projector(torch.cat([output, pc_emb.unsqueeze(0).expand_as(output)], dim=-1))
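The broadcasting in `forward` is easier to follow as an explicit shape trace; this torch-free sketch walks the same steps:

```python
def count_lstm_shapes(M: int, D: int, count: int, max_count: int = 20):
    """Trace the tensor shapes through CountLSTM.forward."""
    count = min(count, max_count)  # cap by max_count, as in forward()
    return {
        "pos_seq": (count, M, D),    # positional embeddings expanded over fields
        "h0": (1, M, D),             # GRU init state from the field embeddings
        "gru_out": (count, M, D),    # one GRU step per count index
        "concat": (count, M, 2 * D), # GRU output || broadcast field embeddings
        "out": (count, M, D),        # after the projector MLP
    }

print(count_lstm_shapes(M=5, D=64, count=8)["out"])  # (8, 5, 64)
```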
class CountLSTMv2(nn.Module):
def __init__(self, hidden_size, max_count=20):
super().__init__()
self.hidden_size = hidden_size
self.max_count = max_count
self.pos_embedding = nn.Embedding(max_count, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
self.transformer = DownscaledTransformer(
hidden_size,
hidden_size=128,
num_heads=4,
num_layers=2,
dropout=0.1,
)
    # NOTE: gold_count_val may be a 0-D Tensor (rather than a Python int) so
    # that the count stays symbolic during ONNX export
    def forward(self, pc_emb: torch.Tensor, gold_count_val) -> torch.Tensor:
        M, D = pc_emb.size()  # symbolic sizes
        # clamp without dropping to a Python int
        gold_count_val = torch.clamp(
            torch.as_tensor(gold_count_val, device=pc_emb.device),
            max=self.max_count,
        )
        # build the *full* index vector once, then slice; ONNX supports both ops
full_idx = torch.arange(self.max_count, device=pc_emb.device)
count_idx = full_idx[:gold_count_val] # (gold_count_val,)
pos_seq = self.pos_embedding(count_idx) # (gold_count_val, D)
pos_seq = pos_seq.unsqueeze(1).expand(-1, M, -1) # (gold_count_val, M, D)
h0 = pc_emb.unsqueeze(0) # (1, M, D)
output, _ = self.gru(pos_seq, h0) # (gold_count_val, M, D)
pc_broadcast = pc_emb.unsqueeze(0).expand_as(output)
return self.transformer(output + pc_broadcast)
class CountLSTMoE(nn.Module):
"""
Count-aware module with a Mixture-of-Experts projector.
Args
----
hidden_size : int
Model dimensionality (D).
max_count : int
Maximum #count steps L.
n_experts : int, optional
Number of FFN experts (default = 4).
ffn_mult : int, optional
        Expansion factor for the expert bottleneck (default = 2, i.e. inner = 2 * D).
dropout : float, optional
Drop-out used inside expert FFNs.
"""
def __init__(self,
hidden_size: int,
max_count: int = 20,
n_experts: int = 4,
ffn_mult: int = 2,
dropout: float = 0.1):
super().__init__()
self.hidden_size, self.max_count, self.n_experts = (
hidden_size, max_count, n_experts)
# ───── positional encoding + recurrent core ─────
self.pos_embedding = nn.Embedding(max_count, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
# ───── expert parameters (all packed in two tensors) ─────
inner = hidden_size * ffn_mult
# W1 : [E, D, inner] b1 : [E, inner]
self.w1 = nn.Parameter(torch.empty(n_experts, hidden_size, inner))
self.b1 = nn.Parameter(torch.zeros(n_experts, inner))
# W2 : [E, inner, D] b2 : [E, D]
self.w2 = nn.Parameter(torch.empty(n_experts, inner, hidden_size))
self.b2 = nn.Parameter(torch.zeros(n_experts, hidden_size))
# better than default init for large mats
nn.init.xavier_uniform_(self.w1)
nn.init.xavier_uniform_(self.w2)
self.dropout = nn.Dropout(dropout)
# ───── router / gating network ─────
self.router = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.GELU(),
nn.Linear(hidden_size, n_experts), # logits
nn.Softmax(dim=-1), # gates sum-to-1
)
# ---------------------------------------------------
def forward(self, pc_emb: torch.Tensor, gold_count_val: int) -> torch.Tensor:
"""
pc_emb : [M, D] field embeddings
gold_count_val : int (# count steps to unroll)
returns : [L, M, D] count-aware embeddings
"""
M, D = pc_emb.shape
L = min(gold_count_val, self.max_count)
idx = torch.arange(L, device=pc_emb.device)
pos_seq = self.pos_embedding(idx).unsqueeze(1).expand(L, M, D)
h0 = pc_emb.unsqueeze(0) # [1, M, D]
h, _ = self.gru(pos_seq, h0) # [L, M, D]
# ───── routing / gating ─────
gates = self.router(h) # [L, M, E]
# ───── expert FFN: run *all* experts in parallel ─────
# 1st linear
x = torch.einsum('lmd,edh->lmeh', h, self.w1) + self.b1 # [L, M, E, inner]
x = F.gelu(x)
x = self.dropout(x)
# 2nd linear
x = torch.einsum('lmeh,ehd->lmed', x, self.w2) + self.b2 # [L, M, E, D]
# ───── mixture weighted by gates ─────
out = (gates.unsqueeze(-1) * x).sum(dim=2) # [L, M, D]
return out
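The gating step above can be illustrated without torch: each position's gates come out of a softmax (so they sum to one), and the output is the gate-weighted sum of the per-expert FFN outputs. A minimal numeric sketch with made-up values:

```python
# Gate-weighted mixture for one position: out = sum_e gates[e] * expert_out[e].
# Shapes mirror the [E] gates and [E, D] expert outputs; values are made up.
def mix_experts(gates, expert_outs):
    assert abs(sum(gates) - 1.0) < 1e-9, "softmax gates must sum to 1"
    dim = len(expert_outs[0])
    return [sum(g * out[d] for g, out in zip(gates, expert_outs))
            for d in range(dim)]

gates = [0.5, 0.25, 0.25]                           # E = 3 experts
expert_outs = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]  # each [D], with D = 2
mixed = mix_experts(gates, expert_outs)             # [1.0, 0.75]
```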

@@ -0,0 +1,692 @@
"""
GLiNER2 Extractor Model with Optimized Batch Processing
This module contains the core Extractor model that accepts PreprocessedBatch
directly for efficient GPU-only forward passes.
"""
import os
import tempfile
from typing import Dict, List, Any, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from gliner.modeling.span_rep import SpanRepLayer
from gliner2.layers import CountLSTMoE, CountLSTM, create_mlp, CountLSTMv2
from gliner2.processor import SchemaTransformer, PreprocessedBatch, SamplingConfig
from safetensors.torch import save_file, load_file
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoModel,
AutoConfig,
AutoTokenizer,
)
class ExtractorConfig(PretrainedConfig):
"""Configuration for the Extractor model."""
model_type = "extractor"
def __init__(
self,
model_name: str = "bert-base-uncased",
max_width: int = 8,
counting_layer: str = "count_lstm",
token_pooling: str = "first",
**kwargs
):
super().__init__(**kwargs)
self.model_name = model_name
self.max_width = max_width
self.counting_layer = counting_layer
self.token_pooling = token_pooling
class Extractor(PreTrainedModel):
"""
GLiNER2 Extractor Model.
This model accepts PreprocessedBatch for efficient training.
Use processor.collate_fn_train() to create batches.
Example:
>>> processor = SchemaTransformer(model_name)
>>> model = Extractor.from_pretrained(repo_id)
>>>
>>> # Training
>>> loader = DataLoader(dataset, collate_fn=processor.collate_fn_train)
>>> for batch in loader:
... batch = batch.to(device)
... loss = model(batch)["total_loss"]
"""
config_class = ExtractorConfig
def __init__(self, config: ExtractorConfig, encoder_config=None, tokenizer=None):
super().__init__(config)
self.config = config
self.max_width = config.max_width
# Initialize processor
if tokenizer is not None:
self.processor = SchemaTransformer(
tokenizer=tokenizer,
token_pooling=config.token_pooling
)
else:
self.processor = SchemaTransformer(
config.model_name,
token_pooling=config.token_pooling
)
# Load encoder
if encoder_config is not None:
self.encoder = AutoModel.from_config(encoder_config, trust_remote_code=True)
else:
self.encoder = AutoModel.from_pretrained(config.model_name, trust_remote_code=True)
self.encoder.resize_token_embeddings(len(self.processor.tokenizer))
self.hidden_size = self.encoder.config.hidden_size
# Span representation layer
self.span_rep = SpanRepLayer(
span_mode="markerV0",
hidden_size=self.hidden_size,
max_width=self.max_width,
dropout=0.1,
)
# Classifier for classification tasks
self.classifier = create_mlp(
input_dim=self.hidden_size,
intermediate_dims=[self.hidden_size * 2],
output_dim=1,
dropout=0.,
activation="relu",
add_layer_norm=False
)
# Count prediction layer
self.count_pred = create_mlp(
input_dim=self.hidden_size,
intermediate_dims=[self.hidden_size * 2],
output_dim=20,
dropout=0.,
activation="relu",
add_layer_norm=False
)
# Count embedding module
if config.counting_layer == "count_lstm":
self.count_embed = CountLSTM(self.hidden_size)
elif config.counting_layer == "count_lstm_moe":
self.count_embed = CountLSTMoE(
hidden_size=self.hidden_size,
n_experts=4,
ffn_mult=2,
dropout=0.1
)
elif config.counting_layer == "count_lstm_v2":
self.count_embed = CountLSTMv2(hidden_size=self.hidden_size)
        else:
            raise ValueError(f"Unknown counting_layer: {config.counting_layer!r}")
# LoRA adapter state
self._lora_layers = {}
self._adapter_config = None
self._print_config(config)
def _print_config(self, config):
print("=" * 60)
print("🧠 Model Configuration")
print("=" * 60)
print(f"Encoder model : {config.model_name}")
print(f"Counting layer : {config.counting_layer}")
print(f"Token pooling : {config.token_pooling}")
print("=" * 60)
# =========================================================================
# Main Forward Pass
# =========================================================================
def forward(
self,
batch: PreprocessedBatch,
return_individual_losses: bool = False
) -> Dict[str, torch.Tensor]:
"""
Forward pass on preprocessed batch.
Args:
batch: PreprocessedBatch from processor.collate_fn_train()
return_individual_losses: If True, return per-sample losses
Returns:
Dict with:
- total_loss: Sum of all losses
- classification_loss: Classification task loss
- structure_loss: Span extraction loss
- count_loss: Count prediction loss
- batch_size: Number of valid samples
"""
if len(batch) == 0:
return self._empty_loss_dict()
device = next(self.parameters()).device
batch = batch.to(device)
# Encode batch through transformer
all_token_embs, all_schema_embs = self._encode_batch(batch)
# Compute losses for each sample
cls_losses = []
struct_losses = []
count_losses = []
individual = []
valid_samples = 0
for i in range(len(batch)):
try:
sample_losses = self._compute_sample_loss(
token_embeddings=all_token_embs[i],
embs_per_schema=all_schema_embs[i],
task_types=batch.task_types[i],
structure_labels=batch.structure_labels[i],
device=device
)
cls_losses.append(sample_losses["classification"])
struct_losses.append(sample_losses["structure"])
count_losses.append(sample_losses["count"])
if return_individual_losses:
individual.append({
"total_loss": (
sample_losses["classification"] +
sample_losses["structure"] +
sample_losses["count"]
).item(),
"classification_loss": sample_losses["classification"].item(),
"structure_loss": sample_losses["structure"].item(),
"count_loss": sample_losses["count"].item(),
})
valid_samples += 1
except Exception as e:
print(f"Error processing sample {i}: {e}")
zero = torch.tensor(0.0, device=device)
cls_losses.append(zero)
struct_losses.append(zero)
count_losses.append(zero)
if return_individual_losses:
individual.append({
"total_loss": 0.0,
"classification_loss": 0.0,
"structure_loss": 0.0,
"count_loss": 0.0,
"error": str(e)
})
if valid_samples == 0:
result = self._empty_loss_dict()
if return_individual_losses:
result["individual_losses"] = individual
return result
# Aggregate losses
total_cls = torch.stack(cls_losses).sum()
total_struct = torch.stack(struct_losses).sum()
total_count = torch.stack(count_losses).sum()
total_loss = total_cls + total_struct + total_count
result = {
"total_loss": total_loss,
"classification_loss": total_cls,
"structure_loss": total_struct,
"count_loss": total_count,
"batch_size": valid_samples
}
if return_individual_losses:
result["individual_losses"] = individual
return result
def _empty_loss_dict(self) -> Dict[str, torch.Tensor]:
"""Return empty loss dictionary."""
device = next(self.parameters()).device
return {
"total_loss": torch.tensor(0.0, device=device, requires_grad=True),
"classification_loss": torch.tensor(0.0, device=device),
"structure_loss": torch.tensor(0.0, device=device),
"count_loss": torch.tensor(0.0, device=device),
"batch_size": 0
}
# =========================================================================
# Encoding
# =========================================================================
def _encode_batch(
self,
batch: PreprocessedBatch
) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
"""
Encode batch through transformer and extract embeddings.
Args:
batch: PreprocessedBatch with input_ids and attention_mask
Returns:
- all_token_embs: List of (text_len, hidden) per sample
- all_schema_embs: List of schema embeddings per sample
"""
# Forward through encoder
outputs = self.encoder(
input_ids=batch.input_ids,
attention_mask=batch.attention_mask
)
token_embeddings = outputs.last_hidden_state
# Extract embeddings using processor
return self.processor.extract_embeddings_from_batch(
token_embeddings,
batch.input_ids,
batch
)
# =========================================================================
# Loss Computation
# =========================================================================
def _compute_sample_loss(
self,
token_embeddings: torch.Tensor,
embs_per_schema: List[List[torch.Tensor]],
task_types: List[str],
structure_labels: List[Any],
device: torch.device
) -> Dict[str, torch.Tensor]:
"""
Compute all losses for a single sample.
Args:
token_embeddings: (text_len, hidden) text token embeddings
embs_per_schema: List of schema embeddings
task_types: Task type for each schema
structure_labels: Labels for each schema
device: Computation device
Returns:
Dict with classification, structure, and count losses
"""
cls_loss = torch.tensor(0.0, device=device)
struct_loss = torch.tensor(0.0, device=device)
count_loss = torch.tensor(0.0, device=device)
# Compute span representations if needed
has_span_task = any(t != "classifications" for t in task_types)
span_info = None
if has_span_task and token_embeddings.numel() > 0:
span_info = self.compute_span_rep(token_embeddings)
all_counts = []
all_p_embs = []
for i, task_type in enumerate(task_types):
if not embs_per_schema[i]:
continue
schema_emb = torch.stack(embs_per_schema[i])
if task_type == "classifications":
# Classification loss
cls_embeds = schema_emb[1:] # Skip [P] token
logits = self.classifier(cls_embeds).squeeze(-1)
labels = torch.tensor(structure_labels[i], dtype=torch.float, device=device)
cls_loss = cls_loss + F.binary_cross_entropy_with_logits(
logits, labels, reduction="sum"
)
else:
# Structure loss
structure = structure_labels[i]
if structure[0] == 0:
# No instances to extract
continue
if span_info is not None:
struct_loss = struct_loss + self.compute_struct_loss(
span_info["span_rep"],
schema_emb,
structure,
span_info["span_mask"]
)
# Collect for count loss (skip entities)
if task_type != "entities":
all_counts.append(min(structure[0], 19))
all_p_embs.append(schema_emb[0])
# Count loss
if all_counts and all_p_embs:
counts = torch.tensor(all_counts, dtype=torch.long, device=device)
p_embs = torch.stack(all_p_embs)
count_loss = F.cross_entropy(self.count_pred(p_embs), counts, reduction="sum")
return {
"classification": cls_loss,
"structure": struct_loss,
"count": count_loss
}
# =========================================================================
# Span Representation
# =========================================================================
def compute_span_rep(self, token_embeddings: torch.Tensor) -> Dict[str, Any]:
"""
Compute span representations for token embeddings.
Args:
token_embeddings: (text_len, hidden) token embeddings
Returns:
Dict with span_rep, spans_idx, and span_mask
"""
text_length = len(token_embeddings)
device = token_embeddings.device
spans_idx = []
for i in range(text_length):
for j in range(self.max_width):
if i + j < text_length:
spans_idx.append((i, i + j))
else:
spans_idx.append((-1, -1))
spans_idx = torch.tensor([spans_idx], dtype=torch.long, device=device)
# Mask invalid spans
span_mask = (spans_idx[:, :, 0] == -1) | (spans_idx[:, :, 1] == -1)
# Replace invalid with (0, 0) for safe indexing
safe_spans = torch.where(
span_mask.unsqueeze(-1),
torch.zeros_like(spans_idx),
spans_idx
)
# Compute span representations
span_rep = self.span_rep(
token_embeddings.unsqueeze(0),
safe_spans
).squeeze(0)
return {
"span_rep": span_rep,
"spans_idx": spans_idx,
"span_mask": span_mask
}
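The span enumeration above produces exactly text_len * max_width candidate (start, end) pairs, padding spans that would run past the text with (-1, -1) so they can be masked later. A standalone sketch of that indexing:

```python
# Enumerate candidate spans exactly as compute_span_rep does: for each start
# position i, widths 0..max_width-1 give spans (i, i + j); spans running past
# the end of the text are padded with (-1, -1) and later masked out.
def enumerate_spans(text_length, max_width):
    spans = []
    for i in range(text_length):
        for j in range(max_width):
            spans.append((i, i + j) if i + j < text_length else (-1, -1))
    return spans

spans = enumerate_spans(text_length=3, max_width=2)
# [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (-1, -1)]
```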
def compute_struct_loss(
self,
span_rep: torch.Tensor,
schema_emb: torch.Tensor,
structure: List[Any],
span_mask: torch.Tensor,
masking_rate: float = 0.5
) -> torch.Tensor:
"""
Compute structure extraction loss with negative span masking.
Args:
            span_rep: (text_len, max_width, hidden) span representations
schema_emb: (num_fields + 1, hidden) schema embeddings
structure: [count, spans] structure labels
span_mask: (1, num_spans) mask for invalid spans
masking_rate: Probability of masking negative spans
Returns:
Structure loss tensor
"""
gold_count = min(structure[0], 19)
struct_proj = self.count_embed(schema_emb[1:], gold_count)
scores = torch.einsum('lkd,bpd->bplk', span_rep, struct_proj)
# Create label tensor
labs = torch.zeros_like(scores)
for i in range(gold_count):
gold_spans = structure[1][i]
for k, span in enumerate(gold_spans):
if span is None or span == (-1, -1):
continue
if isinstance(span, tuple):
start, end = span
width = end - start
if 0 <= start < scores.shape[2] and 0 <= width < scores.shape[3]:
labs[i, k, start, width] = 1
elif isinstance(span, list):
for sub in span:
if sub is None or sub == (-1, -1):
continue
start, end = sub
width = end - start
if 0 <= start < scores.shape[2] and 0 <= width < scores.shape[3]:
labs[i, k, start, width] = 1
# Apply negative masking
if masking_rate > 0.0 and self.training:
negative = (labs == 0)
random_mask = torch.rand_like(scores) < masking_rate
to_mask = negative & random_mask
loss_mask = (~to_mask).float()
else:
loss_mask = torch.ones_like(scores)
# Compute masked loss
loss = F.binary_cross_entropy_with_logits(scores, labs, reduction="none")
loss = loss * loss_mask
loss = loss.view(loss.shape[0], loss.shape[1], -1) * (~span_mask[0]).float()
return loss.sum()
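The label tensor above indexes gold spans as [instance, field, start, width] with width = end - start. A torch-free sketch of that coordinate conversion, using hypothetical gold spans:

```python
# Convert gold (start, end) spans into (field, start, width) label coordinates,
# skipping padding entries and out-of-range spans, mirroring the inner loop of
# compute_struct_loss. text_len and max_width bound the valid label grid.
def spans_to_label_coords(gold_spans, text_len, max_width):
    coords = []
    for k, span in enumerate(gold_spans):
        if span is None or span == (-1, -1):
            continue
        start, end = span
        width = end - start
        if 0 <= start < text_len and 0 <= width < max_width:
            coords.append((k, start, width))
    return coords

coords = spans_to_label_coords([(0, 2), (-1, -1), (4, 4)], text_len=5, max_width=8)
# [(0, 0, 2), (2, 4, 0)]
```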
# =========================================================================
# Hugging Face Methods
# =========================================================================
    def push_to_hub(self, repo_id: str, private: bool = True):
        """Push model and tokenizer to the Hugging Face Hub."""
        # PreTrainedModel.push_to_hub has no save_dir argument; save everything
        # (weights, configs, tokenizer) to a temp dir and upload the folder.
        from huggingface_hub import create_repo, upload_folder
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.save_pretrained(tmp_dir)
            create_repo(repo_id, private=private, exist_ok=True)
            upload_folder(repo_id=repo_id, folder_path=tmp_dir)
@classmethod
def from_pretrained(cls, repo_or_dir: str, **kwargs):
"""
Load model from Hugging Face Hub or local directory.
To use a LoRA adapter:
1. Load the base model first
2. Then load the adapter using model.load_adapter()
Example:
model = Extractor.from_pretrained("base-model-name")
model.load_adapter("path/to/adapter")
"""
from huggingface_hub import hf_hub_download
def download_or_local(repo, filename):
if os.path.isdir(repo):
return os.path.join(repo, filename)
return hf_hub_download(repo, filename)
config_path = download_or_local(repo_or_dir, "config.json")
config = cls.config_class.from_pretrained(config_path)
encoder_config_path = download_or_local(repo_or_dir, "encoder_config/config.json")
encoder_config = AutoConfig.from_pretrained(encoder_config_path)
tokenizer = AutoTokenizer.from_pretrained(repo_or_dir)
model = cls(config, encoder_config=encoder_config, tokenizer=tokenizer)
# Load weights
try:
model_path = download_or_local(repo_or_dir, "model.safetensors")
state_dict = load_file(model_path)
except Exception:
model_path = download_or_local(repo_or_dir, "pytorch_model.bin")
state_dict = torch.load(model_path, map_location="cpu")
# Handle embedding size mismatch
try:
saved_emb = state_dict["encoder.embeddings.word_embeddings.weight"]
model_emb = model.encoder.embeddings.word_embeddings.weight
if saved_emb.shape[0] != model_emb.shape[0]:
extra = model_emb.shape[0] - saved_emb.shape[0]
state_dict["encoder.embeddings.word_embeddings.weight"] = torch.cat([
saved_emb,
                    torch.randn(extra, saved_emb.shape[1], dtype=saved_emb.dtype) * 0.02
], dim=0)
except KeyError:
pass
model.load_state_dict(state_dict)
return model
def load_adapter(self, adapter_path: str) -> 'Extractor':
"""
Load a LoRA adapter onto this model.
If an adapter is already loaded, it will be unloaded first.
Args:
adapter_path: Path to adapter directory
Returns:
self for method chaining
Example:
model.load_adapter("./legal_adapter")
results = model.extract_entities(text, entities)
"""
from gliner2.training.lora import load_lora_adapter, LoRAAdapterConfig
# Load adapter config
config = LoRAAdapterConfig.load(adapter_path)
self._lora_layers = load_lora_adapter(self, adapter_path, auto_unload=True)
self._adapter_config = config
return self
def unload_adapter(self) -> 'Extractor':
"""
Unload current LoRA adapter, restoring base model.
Returns:
self for method chaining
"""
from gliner2.training.lora import unload_lora_adapter
if self._lora_layers:
unload_lora_adapter(self)
self._lora_layers = {}
self._adapter_config = None
return self
def merge_lora(self) -> 'Extractor':
"""
Merge LoRA weights into base model and remove adapter structure.
After calling this, the model will have standard Linear layers with
merged weights. LoRA adapters are permanently removed.
Returns:
self for method chaining
Raises:
ValueError: If no adapter is loaded
Example:
model.load_adapter("./my_adapter")
model.merge_lora() # Now model has merged weights, no LoRA
model.save_pretrained("./merged_model")
"""
if not self._lora_layers:
raise ValueError("No adapter loaded. Nothing to merge.")
from gliner2.training.lora import merge_lora_weights
merge_lora_weights(self)
self._lora_layers = {}
self._adapter_config = None
return self
def save_adapter(self, save_path: str) -> None:
"""
Save only the LoRA adapter (not full model).
Args:
save_path: Directory to save adapter
Raises:
ValueError: If no adapter is loaded
"""
if not self._lora_layers:
raise ValueError("No adapter loaded. Use save_pretrained for full model.")
from gliner2.training.lora import save_lora_adapter
save_lora_adapter(self, save_path)
@property
def has_adapter(self) -> bool:
"""Check if an adapter is currently loaded."""
return bool(self._lora_layers)
@property
def adapter_config(self):
"""Get config of loaded adapter, or None."""
return self._adapter_config
def save_pretrained(
self,
save_directory: str,
save_adapter_only: bool = False,
merge_lora: bool = True,
**kwargs
):
"""
Save model to directory.
Args:
save_directory: Where to save
save_adapter_only: If True and adapter loaded, save only adapter
merge_lora: If True and LoRA active, merge LoRA weights into base
model and remove adapter structure before saving.
WARNING: This permanently removes LoRA from the model instance.
"""
if save_adapter_only:
if not self._lora_layers:
raise ValueError("save_adapter_only=True but no adapter loaded")
self.save_adapter(save_directory)
return
# Handle LoRA merging if requested
if merge_lora and self._lora_layers:
self.merge_lora()
# Original save logic
os.makedirs(save_directory, exist_ok=True)
self.config.save_pretrained(save_directory)
encoder_config_path = os.path.join(save_directory, "encoder_config")
os.makedirs(encoder_config_path, exist_ok=True)
self.encoder.config.save_pretrained(encoder_config_path)
model_path = os.path.join(save_directory, "model.safetensors")
save_file(self.state_dict(), model_path)
self.processor.tokenizer.save_pretrained(save_directory)

@@ -0,0 +1,322 @@
"""
GLiNER2 Trainer with Optimized DataLoader-based Preprocessing
This module provides training utilities that leverage parallel preprocessing
via DataLoader workers for maximum GPU utilization.
"""
import json
import random
from typing import Union, List
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments
from gliner2.processor import SchemaTransformer, PreprocessedBatch, SamplingConfig
# =============================================================================
# Dataset
# =============================================================================
class ExtractorDataset(Dataset):
"""
Dataset for GLiNER2 training.
Returns (text, schema) tuples that are processed by the collate function.
Args:
data_paths: Path or list of paths to JSONL training files
shuffle: Whether to shuffle data on load (default: True)
JSONL Format:
{"input": "text here", "output": {"entities": {...}, ...}}
"""
def __init__(self, data_paths: Union[str, List[str]], shuffle: bool = True):
if isinstance(data_paths, str):
data_paths = [data_paths]
print(f"Loading {len(data_paths)} file(s) for training...")
self.data = []
for path in data_paths:
with open(path, "r", encoding="utf-8") as f:
self.data.extend([json.loads(line) for line in f])
if shuffle:
random.shuffle(self.data)
print(f"Loaded {len(self.data)} records from {len(data_paths)} file(s).")
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> tuple:
"""Return (text, schema) tuple."""
record = self.data[idx]
return record["input"], record["output"]
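Each JSONL line in the documented format parses into the (text, schema) pair that `__getitem__` returns. A minimal sketch with an in-memory example record (field names follow the docstring; the content itself is made up):

```python
import json

# One JSONL line in the documented format: {"input": ..., "output": {...}}.
line = '{"input": "Alice works at Acme.", "output": {"entities": {"person": ["Alice"]}}}'
record = json.loads(line)
# __getitem__ returns the (text, schema) pair consumed by the collator.
text, schema = record["input"], record["output"]
```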
# =============================================================================
# Data Collator
# =============================================================================
class ExtractorDataCollator:
"""
Data collator that uses processor's collate function.
This enables parallel preprocessing via DataLoader workers.
Args:
processor: SchemaTransformer instance
is_training: Whether in training mode (enables augmentation)
"""
def __init__(self, processor: SchemaTransformer, is_training: bool = True):
self.processor = processor
self.is_training = is_training
def __call__(self, batch: List[tuple]) -> PreprocessedBatch:
"""
Collate batch of (text, schema) tuples into PreprocessedBatch.
Args:
batch: List of (text, schema) tuples from dataset
Returns:
PreprocessedBatch ready for model.forward()
"""
if self.is_training:
return self.processor.collate_fn_train(batch)
else:
return self.processor.collate_fn_inference(batch)
# =============================================================================
# Trainer
# =============================================================================
class ExtractorTrainer(Trainer):
"""
Trainer for GLiNER2 with optimized preprocessing.
Key features:
- Parallel preprocessing via DataLoader workers
- Separate learning rates for encoder and other layers
- Optional classifier-only fine-tuning
- FP16 support
- Gradient accumulation
Example:
>>> processor = SchemaTransformer(model_name, sampling_config=config)
>>> collator = ExtractorDataCollator(processor, is_training=True)
>>>
>>> trainer = ExtractorTrainer(
... model=model,
... args=TrainingArguments(
... output_dir="./output",
... per_device_train_batch_size=32,
... dataloader_num_workers=8, # Parallel preprocessing!
... dataloader_pin_memory=True,
... ),
... train_dataset=dataset,
... data_collator=collator,
... encoder_lr=1e-5,
... custom_lr=5e-4,
... weight_decay=0.01,
... )
>>> trainer.train()
"""
def __init__(
self,
encoder_lr: float = 1e-5,
custom_lr: float = 5e-4,
weight_decay: float = 0.01,
finetune_classifier: bool = False,
**kwargs
):
"""
Initialize trainer.
Args:
encoder_lr: Learning rate for encoder parameters
custom_lr: Learning rate for non-encoder parameters
weight_decay: Weight decay for all parameters
finetune_classifier: If True, freeze all except classifier
**kwargs: Arguments passed to Trainer
"""
self.encoder_lr = encoder_lr
self.custom_lr = custom_lr
self.custom_weight_decay = weight_decay
self.finetune_classifier = finetune_classifier
super().__init__(**kwargs)
if self.finetune_classifier:
self._freeze_non_classifier()
def _freeze_non_classifier(self):
"""Freeze all parameters except classifier."""
print("Finetuning classifier only: freezing all other parameters.")
for name, param in self.model.named_parameters():
if not name.startswith("classifier"):
param.requires_grad = False
def create_optimizer(self):
"""Create optimizer with separate parameter groups."""
if self.finetune_classifier:
# Only classifier parameters
classifier_params = [
p for n, p in self.model.named_parameters()
if n.startswith("classifier") and p.requires_grad
]
if not classifier_params:
raise ValueError("No trainable parameters in classifier.")
groups = [{
"params": classifier_params,
"lr": self.custom_lr,
"weight_decay": self.custom_weight_decay,
}]
else:
# Separate encoder and other parameters
encoder_params = list(self.model.encoder.parameters())
other_params = [
p for n, p in self.model.named_parameters()
if "encoder" not in n and p.requires_grad
]
groups = [
{
"params": encoder_params,
"lr": self.encoder_lr,
"weight_decay": self.custom_weight_decay
},
{
"params": other_params,
"lr": self.custom_lr,
"weight_decay": self.custom_weight_decay
},
]
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        self.optimizer = optimizer_cls(groups, **optimizer_kwargs)
        return self.optimizer
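The two-group scheme gives the pretrained encoder a small learning rate and the randomly initialized heads a larger one, with a shared weight decay. A plain-dict sketch of the groups handed to the optimizer in the default branch (values mirror the trainer's defaults; parameter lists are stubbed out):

```python
# Parameter-group layout produced by create_optimizer in the default branch:
# the encoder gets encoder_lr, everything else gets custom_lr, and both groups
# share one weight decay. Real code fills "params" with model parameters.
def make_param_groups(encoder_lr=1e-5, custom_lr=5e-4, weight_decay=0.01):
    return [
        {"params": [], "lr": encoder_lr, "weight_decay": weight_decay},
        {"params": [], "lr": custom_lr, "weight_decay": weight_decay},
    ]

groups = make_param_groups()
```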
def compute_loss(self, model, inputs: PreprocessedBatch, return_outputs: bool = False, **kwargs):
"""
Compute loss on preprocessed batch.
Args:
model: The model
inputs: PreprocessedBatch from collator
return_outputs: Whether to return outputs dict
Returns:
Loss tensor, optionally with outputs dict
"""
# Forward pass - inputs is already PreprocessedBatch
outputs = model(inputs, return_individual_losses=False)
# Handle empty batch
if outputs["batch_size"] == 0:
device = next(model.parameters()).device
loss = torch.tensor(0.0, device=device, requires_grad=True)
else:
loss = outputs["total_loss"]
return (loss, outputs) if return_outputs else loss
# =============================================================================
# Training Utilities
# =============================================================================
def create_training_dataloader(
dataset: ExtractorDataset,
processor: SchemaTransformer,
batch_size: int = 32,
num_workers: int = 8,
pin_memory: bool = True,
shuffle: bool = True,
prefetch_factor: int = 2,
) -> DataLoader:
"""
Create an optimized DataLoader for training.
This function creates a DataLoader configured for maximum preprocessing
efficiency using parallel workers.
Args:
dataset: ExtractorDataset instance
processor: SchemaTransformer for preprocessing
batch_size: Batch size
num_workers: Number of parallel workers for preprocessing
pin_memory: Pin memory for faster GPU transfer
shuffle: Shuffle data each epoch
prefetch_factor: Batches to prefetch per worker
Returns:
Configured DataLoader
Example:
>>> loader = create_training_dataloader(
... dataset=train_dataset,
... processor=processor,
... batch_size=32,
... num_workers=8,
... )
>>> for batch in loader:
... batch = batch.to(device)
... loss = model(batch)["total_loss"]
"""
collator = ExtractorDataCollator(processor, is_training=True)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=pin_memory,
prefetch_factor=prefetch_factor if num_workers > 0 else None,
collate_fn=collator,
persistent_workers=num_workers > 0,
)
def create_inference_dataloader(
texts: List[str],
schemas: List[dict],
processor: SchemaTransformer,
batch_size: int = 32,
num_workers: int = 4,
) -> DataLoader:
"""
Create a DataLoader for inference.
Args:
texts: List of input texts
schemas: List of schemas (same length as texts or single schema)
processor: SchemaTransformer for preprocessing
batch_size: Batch size
num_workers: Number of workers
Returns:
DataLoader yielding PreprocessedBatch
"""
# Handle single schema for all texts
if len(schemas) == 1:
schemas = schemas * len(texts)
dataset = list(zip(texts, schemas))
collator = ExtractorDataCollator(processor, is_training=False)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=collator,
)
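The single-schema broadcast above simply repeats one schema across every text before zipping the pairs into a dataset. A standalone sketch with a hypothetical schema:

```python
# Broadcast one schema over N texts, then pair them up, exactly as
# create_inference_dataloader does before wrapping the pairs in a DataLoader.
texts = ["doc one", "doc two", "doc three"]
schemas = [{"entities": ["person", "company"]}]   # single schema for all texts
if len(schemas) == 1:
    schemas = schemas * len(texts)
dataset = list(zip(texts, schemas))
```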

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -0,0 +1,836 @@
"""
Custom LoRA (Low-Rank Adaptation) Implementation for GLiNER2
=============================================================
Parameter-efficient fine-tuning by injecting trainable low-rank matrices
into frozen linear layers of the encoder.
Based on: "LoRA: Low-Rank Adaptation of Large Language Models"
Paper: https://arxiv.org/abs/2106.09685
"""
from __future__ import annotations
import json
import logging
import math
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
from safetensors.torch import save_file, load_file
logger = logging.getLogger(__name__)
# =============================================================================
# LoRA Configuration
# =============================================================================
@dataclass
class LoRAConfig:
"""
Configuration for LoRA parameter-efficient fine-tuning.
Parameters
----------
enabled : bool
Whether LoRA is enabled.
r : int
Rank of the low-rank decomposition (bottleneck dimension).
Higher r = more parameters but better approximation.
Typical values: 4, 8, 16, 32, 64.
alpha : float
Scaling factor for LoRA updates. Final scaling is alpha/r.
Typical values: 8, 16, 32 (often 2*r).
dropout : float
Dropout probability applied to LoRA path.
target_modules : List[str]
Module names to apply LoRA to. Supported module groups:
- "encoder" - Applies LoRA to query, key, value, dense layers within encoder
- "encoder.query" - Only query layers in encoder
- "encoder.key" - Only key layers in encoder
- "encoder.value" - Only value layers in encoder
- "encoder.dense" - Only dense layers in encoder
- "span_rep" - Applies LoRA to ALL linear layers within span_rep
- "classifier" - Applies LoRA to ALL linear layers within classifier
- "count_embed" - Applies LoRA to ALL linear layers within count_embed
- "count_pred" - Applies LoRA to ALL linear layers within count_pred
Examples:
- ["encoder"] - all encoder layers (query, key, value, dense)
- ["encoder.query", "encoder.key", "encoder.value"] - only attention layers
- ["encoder.dense"] - only dense (FFN) layers in encoder
- ["encoder", "span_rep", "classifier"] - encoder + task heads
- ["classifier"] - classifier fine-tuning only
"""
enabled: bool = False
r: int = 8
alpha: float = 16.0
dropout: float = 0.0
target_modules: List[str] = field(default_factory=lambda: ["encoder"])
def __post_init__(self):
if self.r <= 0:
raise ValueError(f"LoRA rank must be > 0, got {self.r}")
if self.alpha <= 0:
raise ValueError(f"LoRA alpha must be > 0, got {self.alpha}")
if not 0 <= self.dropout < 1:
raise ValueError(f"LoRA dropout must be in [0, 1), got {self.dropout}")
if self.enabled and not self.target_modules:
raise ValueError("target_modules cannot be empty when LoRA is enabled")
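The rank r drives the parameter trade-off directly. A toy, self-contained calculation (the 768x768 layer shape is a hypothetical example, not taken from any particular encoder) shows how few parameters LoRA trains at the defaults r=8, alpha=16.0:

```python
# Illustrative arithmetic only: trainable-parameter count for LoRA on a single
# linear layer, using the r/alpha semantics of LoRAConfig above.
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    # A is (r, in_features), B is (out_features, r)
    return r * in_features + out_features * r

full = 768 * 768              # full fine-tuning of one 768x768 linear layer
lora = lora_param_count(768, 768, r=8)
scaling = 16.0 / 8            # alpha / r, the factor applied to the LoRA path

print(full, lora, round(100 * lora / full, 2), scaling)
```

At r=8 the LoRA path trains roughly 2% of the layer's parameters, which is why r values of 4 to 64 are typical.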
@dataclass
class LoRAAdapterConfig:
"""
Configuration for a saved LoRA adapter.
This is the config that gets saved with adapter-only checkpoints.
"""
adapter_type: str = "lora"
adapter_version: str = "1.0"
lora_r: int = 8
lora_alpha: float = 16.0
lora_dropout: float = 0.0
target_modules: List[str] = field(default_factory=list)
created_at: str = ""
def save(self, path: Union[str, Path]) -> None:
"""Save adapter config to JSON file."""
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
config_path = path / "adapter_config.json"
# Set created_at if not set
if not self.created_at:
self.created_at = datetime.utcnow().isoformat() + "Z"
with open(config_path, "w") as f:
json.dump(asdict(self), f, indent=2)
logger.info(f"Saved adapter config to {config_path}")
@classmethod
def load(cls, path: Union[str, Path]) -> 'LoRAAdapterConfig':
"""Load adapter config from JSON file or directory."""
path = Path(path)
# If path is a directory, look for adapter_config.json
if path.is_dir():
config_path = path / "adapter_config.json"
else:
config_path = path
if not config_path.exists():
raise FileNotFoundError(f"Adapter config not found at {config_path}")
with open(config_path) as f:
config_dict = json.load(f)
return cls(**config_dict)
@classmethod
def is_adapter_path(cls, path: Union[str, Path]) -> bool:
"""Check if path contains an adapter."""
path = Path(path)
# Check for adapter_config.json
if path.is_dir():
return (path / "adapter_config.json").exists()
else:
return path.name == "adapter_config.json" and path.exists()
# =============================================================================
# LoRA Layer
# =============================================================================
class LoRALayer(nn.Module):
"""
LoRA-enhanced Linear layer.
Computes: output = W*x + (B*A*x) * scaling
Where:
- W is the frozen original weight
- A, B are trainable low-rank matrices
- scaling = alpha / r
Parameters
----------
base_layer : nn.Linear
Original linear layer (will be frozen).
r : int
Rank of low-rank decomposition.
alpha : float
LoRA scaling factor.
dropout : float
Dropout probability.
"""
def __init__(
self,
base_layer: nn.Linear,
r: int,
alpha: float,
dropout: float = 0.0,
):
super().__init__()
self.r = r
self.alpha = alpha
self.scaling = alpha / r
in_features = base_layer.in_features
out_features = base_layer.out_features
# Store frozen base layer
self.base_layer = base_layer
for param in self.base_layer.parameters():
param.requires_grad = False
# Get device from base layer to ensure LoRA parameters are on same device
device = next(base_layer.parameters()).device
# LoRA low-rank matrices
# A: (r, in_features) - initialized with small random values
# B: (out_features, r) - initialized to zero (no change at start)
self.lora_A = nn.Parameter(torch.zeros(r, in_features, device=device))
self.lora_B = nn.Parameter(torch.zeros(out_features, r, device=device))
# Initialize A with Kaiming uniform (same as nn.Linear default)
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
# B stays zero-initialized
# Dropout
self.lora_dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
# Flag to track if weights are merged
self.merged = False
# Expose base layer attributes for compatibility
@property
def weight(self):
"""Expose weight from base layer for compatibility."""
return self.base_layer.weight
@property
def bias(self):
"""Expose bias from base layer for compatibility."""
return self.base_layer.bias
@property
def in_features(self):
"""Expose in_features from base layer."""
return self.base_layer.in_features
@property
def out_features(self):
"""Expose out_features from base layer."""
return self.base_layer.out_features
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass with LoRA.
Parameters
----------
x : torch.Tensor
Input tensor of shape (..., in_features).
Returns
-------
torch.Tensor
Output tensor of shape (..., out_features).
"""
# Base output from frozen weights
base_output = self.base_layer(x)
if self.merged:
# Weights already merged, just use base layer
return base_output
# LoRA path: x -> dropout -> A -> B -> scale
# Equivalent to: (x @ A.T) @ B.T * scaling
lora_output = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
return base_output + lora_output * self.scaling
def merge_weights(self):
"""Merge LoRA weights (B @ A) into base layer weights."""
if self.merged:
# Already merged, silently skip
return
with torch.no_grad():
# Compute LoRA contribution: B @ A * scaling
lora_weight = (self.lora_B @ self.lora_A) * self.scaling
# Add to base weight
self.base_layer.weight.data += lora_weight
self.merged = True
logger.debug(f"Merged LoRA weights (r={self.r}) into base layer")
def unmerge_weights(self):
"""Separate LoRA weights from base layer (reverse of merge)."""
if not self.merged:
# Not merged, silently skip
return
with torch.no_grad():
# Subtract LoRA contribution
lora_weight = (self.lora_B @ self.lora_A) * self.scaling
self.base_layer.weight.data -= lora_weight
self.merged = False
logger.debug(f"Unmerged LoRA weights (r={self.r}) from base layer")
def extra_repr(self) -> str:
return f"r={self.r}, alpha={self.alpha}, scaling={self.scaling:.4f}, merged={self.merged}"
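Because lora_B is zero-initialized, a freshly wrapped layer reproduces its base layer exactly; training then perturbs the output gradually. A torch-free sketch of the same forward arithmetic, with plain Python lists standing in for tensors and made-up toy values:

```python
# Minimal re-derivation of the LoRALayer forward math for a 1-D input.
# W, A, B, x are invented toy values, not taken from any real model.
def matvec(M, v):
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

W = [[1.0, 2.0], [3.0, 4.0]]   # frozen base weight (out=2, in=2)
A = [[0.5, -0.5]]              # lora_A, shape (r=1, in=2)
B = [[0.0], [0.0]]             # lora_B, zero-initialized (out=2, r=1)
scaling = 16.0 / 1             # alpha / r

x = [1.0, 1.0]
base = matvec(W, x)                    # W @ x
lora = matvec(B, matvec(A, x))         # B @ (A @ x)
out = [b + l * scaling for b, l in zip(base, lora)]
print(base, out)                       # identical while B is all zeros
```

Once training updates B away from zero, the scaled low-rank path starts contributing, which is exactly what merge_weights() later folds into the base weight.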
# =============================================================================
# LoRA Application Functions
# =============================================================================
# Module-specific patterns for LoRA application
ENCODER_PATTERNS = ["query", "key", "value", "dense"]
ALL_LINEAR_MODULES = ["span_rep", "classifier", "count_embed", "count_pred"]
def apply_lora_to_model(
model: nn.Module,
config: LoRAConfig,
) -> Tuple[nn.Module, Dict[str, LoRALayer]]:
"""
Apply LoRA to linear layers based on module groups in target_modules.
Module group behavior:
- "encoder": Applies LoRA to query, key, value, dense layers within encoder
- "encoder.query": Only query layers in encoder
- "encoder.key": Only key layers in encoder
- "encoder.value": Only value layers in encoder
- "encoder.dense": Only dense layers in encoder
- "span_rep", "classifier", "count_embed", "count_pred": Applies LoRA to ALL linear layers
Parameters
----------
model : nn.Module
The model to apply LoRA to.
config : LoRAConfig
LoRA configuration.
Returns
-------
model : nn.Module
Modified model with LoRA layers.
lora_layers : Dict[str, LoRALayer]
Dictionary mapping layer names to LoRA layers.
"""
if not config.enabled:
logger.info("LoRA is disabled, skipping application")
return model, {}
lora_layers = {}
def _should_apply_lora(local_name: str, full_path: str) -> bool:
"""
Check if LoRA should be applied based on module groups.
Args:
local_name: Local module name (e.g., "query", "linear")
full_path: Full path from model root (e.g., "encoder.layer.0.attention.self.query")
Returns:
True if LoRA should be applied to this layer
"""
for target in config.target_modules:
if target == "encoder":
# For encoder, apply only to specific patterns
if full_path.startswith("encoder."):
# Check if local name matches encoder patterns
if any(pattern in local_name for pattern in ENCODER_PATTERNS):
return True
elif target.startswith("encoder."):
# Specific encoder layer (e.g., "encoder.query", "encoder.dense")
layer_name = target.split(".", 1)[1] # Extract "query" from "encoder.query"
if full_path.startswith("encoder.") and layer_name in local_name:
return True
elif target in ALL_LINEAR_MODULES:
# For these modules, apply to ALL linear layers within
if full_path.startswith(f"{target}."):
return True
return False
# Recursively find and replace modules
def _inject_lora_recursive(module: nn.Module, prefix: str = ""):
for name, child in module.named_children():
full_name = f"{prefix}.{name}" if prefix else name
# Apply LoRA to matching Linear layers
if isinstance(child, nn.Linear) and _should_apply_lora(name, full_name):
# Replace with LoRA layer
lora_layer = LoRALayer(
base_layer=child,
r=config.r,
alpha=config.alpha,
dropout=config.dropout,
)
setattr(module, name, lora_layer)
lora_layers[full_name] = lora_layer
logger.debug(
f"Applied LoRA to {full_name} "
f"(in={child.in_features}, out={child.out_features})"
)
else:
# Recurse into child
_inject_lora_recursive(child, full_name)
_inject_lora_recursive(model)
if not lora_layers:
logger.warning(
f"No LoRA layers were applied. Target modules {config.target_modules} "
f"not found. Check your target_modules configuration."
)
else:
logger.info(f"Applied LoRA to {len(lora_layers)} layers")
return model, lora_layers
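The matching rules can be exercised in isolation. The sketch below mirrors the decision logic of _should_apply_lora for a few invented layer paths; it is illustrative only and does not import the actual function:

```python
# Standalone copy of the module-group matching rules, for illustration.
ENCODER_PATTERNS = ["query", "key", "value", "dense"]
ALL_LINEAR_MODULES = ["span_rep", "classifier", "count_embed", "count_pred"]

def should_apply(local_name, full_path, targets):
    for target in targets:
        if target == "encoder":
            # Whole encoder: only attention/FFN patterns qualify
            if full_path.startswith("encoder.") and any(p in local_name for p in ENCODER_PATTERNS):
                return True
        elif target.startswith("encoder."):
            # Specific encoder layer, e.g. "encoder.query" -> "query"
            layer_name = target.split(".", 1)[1]
            if full_path.startswith("encoder.") and layer_name in local_name:
                return True
        elif target in ALL_LINEAR_MODULES:
            # Task heads: every linear layer inside qualifies
            if full_path.startswith(f"{target}."):
                return True
    return False

targets = ["encoder.query", "classifier"]
print(should_apply("query", "encoder.layer.0.attention.self.query", targets))   # True
print(should_apply("dense", "encoder.layer.0.output.dense", targets))           # False
print(should_apply("linear", "classifier.head.linear", targets))                # True
```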
def get_lora_parameters(model: nn.Module) -> List[nn.Parameter]:
"""
Extract all LoRA parameters (lora_A and lora_B) from model.
Parameters
----------
model : nn.Module
Model with LoRA layers.
Returns
-------
List[nn.Parameter]
List of LoRA parameters.
"""
lora_params = []
for module in model.modules():
if isinstance(module, LoRALayer):
lora_params.extend([module.lora_A, module.lora_B])
return lora_params
def get_lora_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
"""
Get state dict containing only LoRA parameters.
Parameters
----------
model : nn.Module
Model with LoRA layers.
Returns
-------
Dict[str, torch.Tensor]
State dict with LoRA parameters only.
"""
lora_state = {}
for name, module in model.named_modules():
if isinstance(module, LoRALayer):
lora_state[f"{name}.lora_A"] = module.lora_A.data
lora_state[f"{name}.lora_B"] = module.lora_B.data
return lora_state
def merge_lora_weights(model: nn.Module) -> int:
"""
Merge all LoRA weights into base layers and remove LoRA structure.
After calling this, the model will have standard Linear layers with
merged weights. LoRA adapters are removed from the model.
Parameters
----------
model : nn.Module
Model with LoRA layers.
Returns
-------
int
Number of layers whose LoRA weights were newly merged (already-merged
layers are still removed from the model but not counted).
"""
count = 0
already_merged = 0
for module in model.modules():
if isinstance(module, LoRALayer):
if not module.merged:
module.merge_weights()
count += 1
else:
already_merged += 1
if count > 0:
logger.debug(f"Merged LoRA weights in {count} layers")
if already_merged > 0:
logger.debug(f"Skipped {already_merged} layers (already merged)")
# Remove LoRA layers after merging
if count > 0 or already_merged > 0:
remove_lora_from_model(model)
logger.info("Merged and removed LoRA layers from model")
return count
def unmerge_lora_weights(model: nn.Module) -> int:
"""
Unmerge all LoRA weights from their base layers.
Parameters
----------
model : nn.Module
Model with LoRA layers.
Returns
-------
int
Number of layers unmerged.
"""
count = 0
not_merged = 0
for module in model.modules():
if isinstance(module, LoRALayer):
if module.merged:
module.unmerge_weights()
count += 1
else:
not_merged += 1
if count > 0:
logger.debug(f"Unmerged LoRA weights in {count} layers")
if not_merged > 0:
logger.debug(f"Skipped {not_merged} layers (not merged)")
return count
def count_lora_parameters(model: nn.Module) -> Tuple[int, int, float]:
"""
Count LoRA parameters vs total parameters.
Parameters
----------
model : nn.Module
Model with LoRA layers.
Returns
-------
lora_params : int
Number of trainable LoRA parameters.
total_params : int
Total number of model parameters.
percentage : float
Percentage of trainable parameters.
"""
lora_params = sum(p.numel() for p in get_lora_parameters(model))
total_params = sum(p.numel() for p in model.parameters())
percentage = (lora_params / total_params * 100) if total_params > 0 else 0.0
return lora_params, total_params, percentage
def print_lora_info(model: nn.Module, config: LoRAConfig):
"""
Print detailed LoRA configuration and parameter statistics.
Parameters
----------
model : nn.Module
Model with LoRA layers.
config : LoRAConfig
LoRA configuration.
"""
lora_params, total_params, percentage = count_lora_parameters(model)
# Count LoRA layers
num_lora_layers = sum(1 for m in model.modules() if isinstance(m, LoRALayer))
print("=" * 70)
print("🔧 LoRA Configuration")
print("=" * 70)
print(f"Enabled : {config.enabled}")
print(f"Rank (r) : {config.r}")
print(f"Alpha : {config.alpha}")
print(f"Scaling (α/r) : {config.alpha / config.r:.4f}")
print(f"Dropout : {config.dropout}")
print(f"Target modules : {', '.join(config.target_modules)}")
print(f"LoRA layers : {num_lora_layers}")
print("-" * 70)
print(f"Trainable params : {lora_params:,} / {total_params:,} ({percentage:.2f}%)")
print(f"Memory savings : ~{100 - percentage:.1f}% fewer gradients")
print("=" * 70)
def remove_lora_from_model(model: nn.Module) -> nn.Module:
"""
Remove LoRA layers and restore original Linear layers.
Useful for inference with merged weights.
Parameters
----------
model : nn.Module
Model with LoRA layers.
Returns
-------
nn.Module
Model with LoRA layers replaced by standard Linear layers.
"""
def _remove_lora_recursive(module: nn.Module):
for name, child in module.named_children():
if isinstance(child, LoRALayer):
# Ensure weights are merged
if not child.merged:
child.merge_weights()
# Replace LoRALayer with its base layer
setattr(module, name, child.base_layer)
logger.debug(f"Removed LoRA from {name}, restored base layer")
else:
_remove_lora_recursive(child)
_remove_lora_recursive(model)
logger.info("Removed all LoRA layers from model")
return model
# =============================================================================
# Adapter Management Functions
# =============================================================================
def save_lora_adapter(
model: nn.Module,
save_path: Union[str, Path],
) -> None:
"""
Save only LoRA adapter weights and config.
Parameters
----------
model : nn.Module
Model with LoRA layers (must NOT be merged).
save_path : Union[str, Path]
Directory to save the adapter to. Two files are written:
adapter_config.json and adapter_weights.safetensors.
"""
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
# Collect LoRA weights and config
lora_state = {}
lora_config = None
for name, module in model.named_modules():
if isinstance(module, LoRALayer):
if module.merged:
raise ValueError(
"Cannot save adapter with merged weights. "
"Call unmerge_lora_weights() first."
)
# Save LoRA matrices with full path from model root
lora_state[f"{name}.lora_A"] = module.lora_A.data
lora_state[f"{name}.lora_B"] = module.lora_B.data
# Extract config from first LoRA layer
if lora_config is None:
lora_config = {
"lora_r": module.r,
"lora_alpha": module.alpha,
"lora_dropout": module.lora_dropout.p if hasattr(module.lora_dropout, 'p') else 0.0,
}
if not lora_state:
raise ValueError("No LoRA layers found in model")
# Save weights
weights_path = save_path / "adapter_weights.safetensors"
save_file(lora_state, str(weights_path))
logger.info(f"Saved {len(lora_state)} LoRA tensors to {weights_path}")
# Determine target modules from layer names
# Extract top-level module names (encoder, span_rep, classifier, etc.)
target_modules = set()
for key in lora_state.keys():
# Extract first level module from full path
# e.g., "encoder.layer.0.attention.self.query.lora_A" -> "encoder"
# e.g., "span_rep.project_start.0.lora_A" -> "span_rep"
parts = key.split(".")
if len(parts) > 0:
# Get the first level module name
module_name = parts[0]
target_modules.add(module_name)
# Create and save adapter config
adapter_config = LoRAAdapterConfig(
adapter_type="lora",
adapter_version="1.0",
lora_r=lora_config["lora_r"],
lora_alpha=lora_config["lora_alpha"],
lora_dropout=lora_config["lora_dropout"],
target_modules=sorted(list(target_modules)),
created_at=datetime.utcnow().isoformat() + "Z"
)
adapter_config.save(save_path)
logger.info(f"Saved LoRA adapter to {save_path}")
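The target-module recovery above depends only on the first dot-separated component of each tensor key. A standalone sketch with hypothetical key names:

```python
# Hypothetical tensor keys, shaped like those written by save_lora_adapter.
lora_state_keys = [
    "encoder.layer.0.attention.self.query.lora_A",
    "encoder.layer.0.attention.self.query.lora_B",
    "span_rep.project_start.0.lora_A",
    "classifier.head.lora_B",
]

# Same derivation as in save_lora_adapter: the top-level module is the
# first dot-separated component of the key.
target_modules = sorted({key.split(".")[0] for key in lora_state_keys})
print(target_modules)
```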
def load_lora_adapter(
model: nn.Module,
adapter_path: Union[str, Path],
auto_unload: bool = True,
) -> Dict[str, LoRALayer]:
"""
Load LoRA adapter onto model.
Parameters
----------
model : nn.Module
Base model (should not already have LoRA applied).
adapter_path : Union[str, Path]
Path to the adapter directory.
auto_unload : bool
If True, unload any existing adapter before loading the new one.
Returns
-------
Dict[str, LoRALayer]
LoRA layers that were applied.
"""
adapter_path = Path(adapter_path)
# Load adapter config
adapter_config = LoRAAdapterConfig.load(adapter_path)
# Unload existing adapter if requested
if auto_unload and has_lora_adapter(model):
logger.info("Unloading existing adapter before loading new one")
unload_lora_adapter(model)
# Load adapter weights
weights_path = adapter_path / "adapter_weights.safetensors"
if not weights_path.exists():
raise FileNotFoundError(f"Adapter weights not found at {weights_path}")
lora_state = load_file(str(weights_path))
logger.info(f"Loaded {len(lora_state)} LoRA tensors from {weights_path}")
# Apply LoRA to matching layers
lora_config = LoRAConfig(
enabled=True,
r=adapter_config.lora_r,
alpha=adapter_config.lora_alpha,
dropout=adapter_config.lora_dropout,
target_modules=adapter_config.target_modules,
)
model, lora_layers = apply_lora_to_model(model, lora_config)
# Load saved weights into LoRA layers
for name, module in model.named_modules():
if isinstance(module, LoRALayer):
lora_a_key = f"{name}.lora_A"
lora_b_key = f"{name}.lora_B"
if lora_a_key in lora_state and lora_b_key in lora_state:
# Move loaded tensors to the same device as the module
device = next(module.parameters()).device
module.lora_A.data = lora_state[lora_a_key].to(device)
module.lora_B.data = lora_state[lora_b_key].to(device)
logger.debug(f"Loaded weights for {name}")
else:
logger.warning(f"No saved weights found for {name}")
logger.info(f"Loaded LoRA adapter from {adapter_path}")
return lora_layers
def unload_lora_adapter(model: nn.Module) -> int:
"""
Remove all LoRA layers, restoring original Linear layers.
Unlike remove_lora_from_model, this does NOT merge weights into the
base layers; the LoRA layers are simply discarded.
Returns
-------
int
Number of layers unloaded.
"""
count = 0
def _get_parent_module(model: nn.Module, full_name: str) -> Tuple[nn.Module, str]:
"""Get parent module and child name from full module path."""
parts = full_name.split('.')
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
return parent, parts[-1]
# Collect all LoRA layers first (to avoid modifying dict during iteration)
lora_layers = []
for name, module in model.named_modules():
if isinstance(module, LoRALayer):
lora_layers.append((name, module))
# Remove LoRA layers
for name, lora_layer in lora_layers:
parent, child_name = _get_parent_module(model, name)
# Replace with original base_layer (no merge)
setattr(parent, child_name, lora_layer.base_layer)
count += 1
logger.debug(f"Unloaded LoRA from {name}")
if count > 0:
logger.info(f"Unloaded {count} LoRA layers")
return count
def has_lora_adapter(model: nn.Module) -> bool:
"""Check if model has LoRA layers applied."""
for module in model.modules():
if isinstance(module, LoRALayer):
return True
return False
def get_adapter_config(model: nn.Module) -> Optional[LoRAAdapterConfig]:
"""
Get config of currently loaded adapter, if any.
Note: This reconstructs config from LoRA layers.
The actual adapter config is stored in model._adapter_config
when loaded via model.load_adapter().
"""
if not has_lora_adapter(model):
return None
# Extract config from first LoRA layer
for module in model.modules():
if isinstance(module, LoRALayer):
target_modules = set()
# Collect all target module groups (top-level modules)
for name, m in model.named_modules():
if isinstance(m, LoRALayer):
# Extract first level module name
parts = name.split(".")
if parts:
target_modules.add(parts[0])
return LoRAAdapterConfig(
adapter_type="lora",
adapter_version="1.0",
lora_r=module.r,
lora_alpha=module.alpha,
lora_dropout=module.lora_dropout.p if hasattr(module.lora_dropout, 'p') else 0.0,
target_modules=sorted(list(target_modules)),
created_at=""
)
return None

[build-system]
requires = ["setuptools>=61.0.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
include = ["gliner2", "gliner2.*"]
[tool.setuptools.dynamic]
version = {attr = "gliner2.__version__"}
[project]
name = "gliner2"
readme = "README.md"
requires-python = ">=3.8"
maintainers = [
{name = "Urchade Zaratiana"},
]
dependencies = [
"gliner",
"pydantic>=2.0.0",
]
dynamic = ["version"]

# GLiNER2 Classification Tutorial
This tutorial covers all the ways to perform text classification with GLiNER2, from simple single-label classification to complex multi-label tasks with custom configurations.
## Table of Contents
- [Setup](#setup)
- [Single-Label Classification](#single-label-classification)
- [Multi-Label Classification](#multi-label-classification)
- [Classification with Descriptions](#classification-with-descriptions)
- [Using the Quick API](#using-the-quick-api)
- [Multiple Classification Tasks](#multiple-classification-tasks)
- [Advanced Configurations](#advanced-configurations)
- [Best Practices](#best-practices)
## Setup
```python
from gliner2 import GLiNER2
# Load the pre-trained model
extractor = GLiNER2.from_pretrained("your-model-name")
```
## Single-Label Classification
The simplest form - classify text into one of several categories.
### Basic Example
```python
# Define the schema
schema = extractor.create_schema().classification(
"sentiment",
["positive", "negative", "neutral"]
)
# Extract
text = "This product exceeded my expectations! Absolutely love it."
results = extractor.extract(text, schema)
print(results)
# Expected output: {'sentiment': 'positive'}
```
### With Confidence Scores
```python
# Same schema as above
schema = extractor.create_schema().classification(
"sentiment",
["positive", "negative", "neutral"]
)
text = "The service was okay, nothing special but not bad either."
results = extractor.extract(text, schema, include_confidence=True)
print(results)
# Expected output: {'sentiment': {'label': 'neutral', 'confidence': 0.82}}
```
## Multi-Label Classification
When text can belong to multiple categories simultaneously.
```python
# Multi-label classification
schema = extractor.create_schema().classification(
"topics",
["technology", "business", "health", "politics", "sports"],
multi_label=True,
cls_threshold=0.3 # Lower threshold for multi-label
)
text = "Apple announced new health monitoring features in their latest smartwatch, boosting their stock price."
results = extractor.extract(text, schema)
print(results)
# Expected output: {'topics': ['technology', 'business', 'health']}
# With confidence scores
results = extractor.extract(text, schema, include_confidence=True)
print(results)
# Expected output: {'topics': [
# {'label': 'technology', 'confidence': 0.92},
# {'label': 'business', 'confidence': 0.78},
# {'label': 'health', 'confidence': 0.65}
# ]}
```
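Conceptually, single-label classification takes the highest-scoring label, while multi-label keeps every label that clears `cls_threshold`. A minimal sketch of that selection logic (made-up scores; this is not GLiNER2's internal implementation):

```python
# Illustrative only: how cls_threshold changes label selection.
# The scores below are invented, not produced by the model.
scores = {"technology": 0.92, "business": 0.78, "health": 0.65,
          "politics": 0.08, "sports": 0.03}

single_label = max(scores, key=scores.get)                      # argmax
multi_label = [lbl for lbl, s in scores.items() if s >= 0.3]    # cls_threshold=0.3

print(single_label)   # technology
print(multi_label)    # ['technology', 'business', 'health']
```

Lowering `cls_threshold` admits more labels per text; raising it makes multi-label output stricter.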
## Classification with Descriptions
Adding descriptions significantly improves accuracy by providing context.
```python
# With label descriptions
schema = extractor.create_schema().classification(
"document_type",
{
"invoice": "A bill for goods or services with payment details",
"receipt": "Proof of payment for a completed transaction",
"contract": "Legal agreement between parties with terms and conditions",
"proposal": "Document outlining suggested plans or services with pricing"
}
)
text = "Please find attached the itemized bill for consulting services rendered in Q3 2024. Payment is due within 30 days."
results = extractor.extract(text, schema)
print(results)
# Expected output: {'document_type': 'invoice'}
# Another example
text2 = "Thank you for your payment of $500. This confirms your transaction was completed on March 1st, 2024."
results2 = extractor.extract(text2, schema)
print(results2)
# Expected output: {'document_type': 'receipt'}
```
## Using the Quick API
For simple classification tasks without building a schema.
### Single Task
```python
text = "The new AI model shows remarkable performance improvements."
results = extractor.classify_text(
text,
{"sentiment": ["positive", "negative", "neutral"]}
)
print(results)
# Expected output: {'sentiment': 'positive'}
# Another example
text2 = "The software keeps crashing and customer support is unresponsive."
results2 = extractor.classify_text(
text2,
{"sentiment": ["positive", "negative", "neutral"]}
)
print(results2)
# Expected output: {'sentiment': 'negative'}
```
### Multiple Tasks
```python
text = "Breaking: Tech giant announces major layoffs amid market downturn"
results = extractor.classify_text(
text,
{
"sentiment": ["positive", "negative", "neutral"],
"urgency": ["high", "medium", "low"],
"category": {
"labels": ["tech", "finance", "politics", "sports"],
"multi_label": False
}
}
)
print(results)
# Expected output: {
# 'sentiment': 'negative',
# 'urgency': 'high',
# 'category': 'tech'
# }
```
### Multi-Label with Config
```python
text = "The smartphone features an amazing camera but disappointing battery life and overheats frequently."
results = extractor.classify_text(
text,
{
"product_aspects": {
"labels": ["camera", "battery", "display", "performance", "design", "heating"],
"multi_label": True,
"cls_threshold": 0.4
}
}
)
print(results)
# Expected output: {'product_aspects': ['camera', 'battery', 'heating']}
# Another example
text2 = "Beautiful design with vibrant display, though the camera could be better."
results2 = extractor.classify_text(
text2,
{
"product_aspects": {
"labels": ["camera", "battery", "display", "performance", "design", "heating"],
"multi_label": True,
"cls_threshold": 0.4
}
}
)
print(results2)
# Expected output: {'product_aspects': ['design', 'display', 'camera']}
```
## Multiple Classification Tasks
You can include multiple classification tasks in a single schema for comprehensive text analysis.
### Basic Multiple Classifications
```python
# Multiple independent classifications
schema = (extractor.create_schema()
.classification("sentiment", ["positive", "negative", "neutral"])
.classification("language", ["english", "spanish", "french", "german", "other"])
.classification("formality", ["formal", "informal", "semi-formal"])
.classification("intent", ["question", "statement", "request", "complaint"])
)
text = "Could you please help me with my order? The service has been disappointing."
results = extractor.extract(text, schema)
print(results)
# Expected output: {
# 'sentiment': 'negative',
# 'language': 'english',
# 'formality': 'formal',
# 'intent': 'question'
# }
# Another example
text2 = "Hey! Just wanted to say your product rocks! 🎉"
results2 = extractor.extract(text2, schema)
print(results2)
# Expected output: {
# 'sentiment': 'positive',
# 'language': 'english',
# 'formality': 'informal',
# 'intent': 'statement'
# }
```
### Mixed Single and Multi-Label Classifications
```python
# Combine different classification types
schema = (extractor.create_schema()
# Single-label classifications
.classification("primary_topic", ["tech", "business", "health", "sports", "politics"])
.classification("urgency", ["immediate", "soon", "later", "not_urgent"])
# Multi-label classifications
.classification("emotions",
["happy", "sad", "angry", "surprised", "fearful", "disgusted"],
multi_label=True,
cls_threshold=0.4
)
.classification("content_flags",
["inappropriate", "spam", "promotional", "personal_info", "financial_info"],
multi_label=True,
cls_threshold=0.3
)
)
text = "URGENT: I'm thrilled to announce our new product! But concerned about competitor reactions. Please keep confidential."
results = extractor.extract(text, schema)
print(results)
# Expected output: {
# 'primary_topic': 'business',
# 'urgency': 'immediate',
# 'emotions': ['happy', 'fearful'],
# 'content_flags': ['promotional', 'personal_info']
# }
# Another example
text2 = "Just saw the game - absolutely devastated by the loss. Can't believe the referee's terrible decision!"
results2 = extractor.extract(text2, schema)
print(results2)
# Expected output: {
# 'primary_topic': 'sports',
# 'urgency': 'not_urgent',
# 'emotions': ['sad', 'angry'],
# 'content_flags': []
# }
```
### Domain-Specific Multiple Classifications
```python
# Customer support ticket classification
support_schema = (extractor.create_schema()
.classification("ticket_type",
["technical_issue", "billing", "feature_request", "bug_report", "other"])
.classification("priority",
["critical", "high", "medium", "low"],
cls_threshold=0.7
)
.classification("product_area",
{
"authentication": "Login, passwords, security",
"payment": "Payment processing, subscriptions",
"ui": "User interface, design issues",
"performance": "Speed, loading, responsiveness",
"data": "Data loss, corruption, sync issues"
},
multi_label=True,
cls_threshold=0.5
)
.classification("customer_sentiment",
["very_satisfied", "satisfied", "neutral", "frustrated", "very_frustrated"],
cls_threshold=0.6
)
.classification("requires_action",
["immediate_response", "investigation_needed", "waiting_customer", "resolved"],
multi_label=True
)
)
ticket_text = """
Subject: Cannot login - Urgent!
I've been trying to login for the past hour but keep getting error messages.
This is critical as I need to process payments for my customers today.
The page just keeps spinning and then times out. I'm extremely frustrated
as this is costing me business. Please fix this immediately!
"""
results = extractor.extract(ticket_text, support_schema)
print(results)
# Expected output: {
# 'ticket_type': 'technical_issue',
# 'priority': 'critical',
# 'product_area': ['authentication', 'payment', 'performance'],
# 'customer_sentiment': 'very_frustrated',
# 'requires_action': ['immediate_response', 'investigation_needed']
# }
# Another support ticket example
ticket_text2 = """
Hi team,
Thanks for the great product! I was wondering if you could add a dark mode feature?
It would really help with eye strain during late night work sessions.
Best regards,
Happy Customer
"""
results2 = extractor.extract(ticket_text2, support_schema)
print(results2)
# Expected output: {
# 'ticket_type': 'feature_request',
# 'priority': 'low',
# 'product_area': ['ui'],
# 'customer_sentiment': 'satisfied',
# 'requires_action': ['waiting_customer']
# }
```
### Sequential Classification with Dependencies
```python
# Email routing and handling classification
email_schema = (extractor.create_schema()
# Primary classification
.classification("email_category",
["sales", "support", "hr", "legal", "general"],
cls_threshold=0.6
)
# Secondary classifications based on context
.classification("sales_stage",
["lead", "qualified", "proposal", "negotiation", "closed"],
cls_threshold=0.5
)
.classification("support_type",
["pre_sales", "technical", "account", "billing"],
cls_threshold=0.5
)
# Action classifications
.classification("required_action",
["reply_needed", "forward_to_team", "schedule_meeting", "no_action"],
multi_label=True,
cls_threshold=0.4
)
.classification("response_timeframe",
["within_1_hour", "within_24_hours", "within_week", "non_urgent"],
cls_threshold=0.6
)
)
email = """
Hi Sales Team,
I'm interested in your enterprise solution. We're currently evaluating vendors
for our upcoming project. Could we schedule a demo next week? We need to make
a decision by month end.
Best regards,
John from TechCorp
"""
results = extractor.extract(email, email_schema)
print(results)
# Expected output: {
# 'email_category': 'sales',
# 'sales_stage': 'qualified',
# 'support_type': 'pre_sales',
# 'required_action': ['reply_needed', 'schedule_meeting'],
# 'response_timeframe': 'within_24_hours'
# }
# HR email example
email2 = """
Dear HR Department,
I need to update my tax withholding information. Could someone please send me
the necessary forms? This is somewhat urgent as I need this changed before the
next payroll cycle.
Thank you,
Sarah
"""
results2 = extractor.extract(email2, email_schema)
print(results2)
# Expected output: {
# 'email_category': 'hr',
# 'sales_stage': 'lead', # May have noise in non-sales emails
# 'support_type': 'account',
# 'required_action': ['reply_needed'],
# 'response_timeframe': 'within_24_hours'
# }
```
### Complex Analysis with Multiple Classifications
```python
# Content moderation and analysis
content_schema = (extractor.create_schema()
# Content classifications
.classification("content_type",
["article", "comment", "review", "social_post", "message"])
.classification("primary_language",
["english", "spanish", "french", "other"])
# Quality assessments
.classification("quality_score",
["excellent", "good", "average", "poor", "spam"],
cls_threshold=0.7
)
.classification("originality",
["original", "derivative", "duplicate", "plagiarized"],
cls_threshold=0.8
)
# Safety and compliance
.classification("safety_flags",
{
"hate_speech": "Contains discriminatory or hateful content",
"violence": "Contains violent or threatening content",
"adult": "Contains adult or explicit content",
"misinformation": "Contains potentially false information",
"personal_info": "Contains personal identifying information"
},
multi_label=True,
cls_threshold=0.3
)
# Engagement predictions
.classification("engagement_potential",
["viral", "high", "medium", "low"],
cls_threshold=0.6
)
.classification("audience_fit",
["general", "professional", "academic", "youth", "senior"],
multi_label=True,
cls_threshold=0.5
)
)
content_text = """
Just discovered this amazing productivity hack that doubled my output!
Here's what I do: I wake up at 5 AM, meditate for 20 minutes, then work
in 90-minute focused blocks. The results have been incredible. My email
is john.doe@example.com if you want more tips!
"""
results = extractor.extract(content_text, content_schema)
print(results)
# Expected output: {
# 'content_type': 'social_post',
# 'primary_language': 'english',
# 'quality_score': 'good',
# 'originality': 'original',
# 'safety_flags': ['personal_info'],
# 'engagement_potential': 'high',
# 'audience_fit': ['general', 'professional']
# }
# Review example
review_text = """
Worst product ever!!! Total scam! Don't buy this garbage. The company should
be shut down for selling this junk. I'm going to report them to authorities.
"""
results2 = extractor.extract(review_text, content_schema)
print(results2)
# Expected output: {
# 'content_type': 'review',
# 'primary_language': 'english',
# 'quality_score': 'poor',
# 'originality': 'original',
# 'safety_flags': ['violence'], # Due to aggressive language
# 'engagement_potential': 'low',
# 'audience_fit': ['general']
# }
```
## Advanced Configurations
### Custom Thresholds
```python
# High-precision classification
schema = extractor.create_schema().classification(
"is_spam",
["spam", "not_spam"],
cls_threshold=0.9 # Very high confidence required
)
text = "Congratulations! You've won $1,000,000! Click here to claim your prize now!"
results = extractor.extract(text, schema)
print(results)
# Expected output: {'is_spam': 'spam'}
# Different thresholds for different tasks
schema = (extractor.create_schema()
.classification("priority", ["urgent", "high", "normal", "low"], cls_threshold=0.8)
.classification("department", ["sales", "support", "billing", "other"], cls_threshold=0.5)
)
text = "URGENT: Customer threatening to cancel $50k contract due to billing error"
results = extractor.extract(text, schema)
print(results)
# Expected output: {
# 'priority': 'urgent',
# 'department': 'billing'
# }
```
### Custom Activation Functions
```python
# Force specific activation
schema = extractor.create_schema().classification(
"category",
["A", "B", "C", "D"],
class_act="softmax" # Options: "sigmoid", "softmax", "auto"
)
text = "This clearly belongs to category B based on the criteria."
results = extractor.extract(text, schema)
print(results)
# Expected output: {'category': 'B'}
```
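The difference between the two activations is easy to see in plain Python (the logit values below are made up for illustration, not real model scores): softmax makes labels compete for a single probability mass, while sigmoid scores each label independently, which is why multi-label classification relies on sigmoid.

```python
import math

def softmax(scores):
    """Mutually exclusive: probabilities compete and sum to 1."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(scores):
    """Independent: each label gets its own probability in (0, 1)."""
    return [1 / (1 + math.exp(-s)) for s in scores]

logits = [2.0, 0.5, -1.0]  # raw scores for labels A, B, C (illustrative)
print([round(p, 3) for p in softmax(logits)])  # sums to 1.0
print([round(p, 3) for p in sigmoid(logits)])  # each judged on its own
```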
### Complex Multi-Label Example
```python
# Email classification system
schema = extractor.create_schema().classification(
"email_tags",
{
"action_required": "Email requires recipient to take action",
"meeting_request": "Email contains meeting invitation or scheduling",
"project_update": "Email contains project status or updates",
"urgent": "Email marked as urgent or time-sensitive",
"question": "Email contains questions requiring answers",
"fyi": "Informational email requiring no action"
},
multi_label=True,
cls_threshold=0.35
)
email_text = """
Hi team,
Quick update on Project Alpha: We're ahead of schedule!
However, I need your input on the design mockups by EOD tomorrow.
Can we schedule a 30-min call this week to discuss?
This is quite urgent as the client is waiting.
Best,
Sarah
"""
results = extractor.extract(email_text, schema)
print(results)
# Expected output: {
# 'email_tags': ['action_required', 'meeting_request', 'project_update', 'urgent', 'question']
# }
# FYI email example
email_text2 = """
Team,
Just wanted to let everyone know that I'll be out of office next Monday for a
doctor's appointment. I'll be back Tuesday morning.
Thanks,
Mark
"""
results2 = extractor.extract(email_text2, schema)
print(results2)
# Expected output: {
# 'email_tags': ['fyi']
# }
```
## Best Practices
1. **Use Descriptions**: Always provide label descriptions when possible
```python
# Good - with descriptions
schema = extractor.create_schema().classification(
"intent",
{
"purchase": "User wants to buy a product",
"return": "User wants to return a product",
"inquiry": "User asking for information"
}
)
# Less effective - no context
schema = extractor.create_schema().classification(
"intent",
["purchase", "return", "inquiry"]
)
```
2. **Adjust Thresholds**: Lower thresholds for multi-label (0.3-0.5), higher for single-label (0.5-0.7)
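How a threshold interacts with single- vs multi-label mode can be sketched in plain Python (the scores and the `apply_thresholds` helper are illustrative, not part of the GLiNER2 API):

```python
def apply_thresholds(scores, threshold, multi_label):
    """Keep labels scoring above threshold; single-label keeps only the best."""
    passing = {label: s for label, s in scores.items() if s >= threshold}
    if not passing:
        return [] if multi_label else None
    if multi_label:
        return sorted(passing, key=passing.get, reverse=True)
    return max(passing, key=passing.get)

scores = {"waterproof": 0.62, "wireless": 0.48, "portable": 0.35}
print(apply_thresholds(scores, 0.4, multi_label=True))   # ['waterproof', 'wireless']
print(apply_thresholds(scores, 0.6, multi_label=False))  # 'waterproof'
```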
3. **Multi-Label Strategy**: Use multi-label when categories aren't mutually exclusive
```python
# Good use of multi-label
schema = extractor.create_schema().classification(
"product_features",
["waterproof", "wireless", "rechargeable", "portable"],
multi_label=True
)
# Should be single-label
schema = extractor.create_schema().classification(
"size",
["small", "medium", "large"],
multi_label=False # Sizes are mutually exclusive
)
```
4. **Test with Real Examples**: Always test with actual text samples from your domain
## Common Use Cases
- **Sentiment Analysis**: Customer feedback, reviews, social media
- **Intent Classification**: Chatbots, customer service routing
- **Document Classification**: Email filtering, document management
- **Content Moderation**: Toxic content, spam detection
- **Topic Classification**: News categorization, content tagging
# Tutorial 10: LoRA Adapters - Multi-Domain Inference
## Table of Contents
1. [Introduction](#introduction)
2. [Why Use LoRA Adapters?](#why-use-lora-adapters)
3. [Training Your First Adapter](#training-your-first-adapter)
4. [Training Multiple Domain Adapters](#training-multiple-domain-adapters)
5. [Loading and Swapping Adapters](#loading-and-swapping-adapters)
6. [Real-World Use Cases](#real-world-use-cases)
7. [Best Practices](#best-practices)
8. [Troubleshooting](#troubleshooting)
## Introduction
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that allows you to train specialized adapters for different domains without modifying the base model. This enables:
- **Fast domain switching**: Swap between domains in milliseconds
- **Minimal storage**: Adapters are ~2-10 MB vs ~100-500 MB for full models
- **Domain specialization**: Train separate adapters for legal, medical, financial, etc.
- **Easy deployment**: Keep one base model + multiple lightweight adapters
## Why Use LoRA Adapters?
### Memory Efficiency
```
Full Model Fine-tuning:
- Legal model: 450 MB
- Medical model: 450 MB
- Financial model: 450 MB
Total: 1.35 GB
LoRA Adapters:
- Base model: 450 MB
- Legal adapter: 5 MB
- Medical adapter: 5 MB
- Financial adapter: 5 MB
Total: 465 MB (65% less!)
```
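A quick sanity check of the arithmetic above (the sizes are the illustrative figures from the comparison, not measurements):

```python
full_models_mb = 3 * 450                 # one fully fine-tuned model per domain
lora_setup_mb = 450 + 3 * 5              # one base model + three adapters
savings = 1 - lora_setup_mb / full_models_mb
print(full_models_mb, lora_setup_mb)     # 1350 465
print(f"{savings:.1%} less storage")     # about 65.6% less
```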
### Fast Training
LoRA adapters train **2-3x faster** than full fine-tuning because:
- Only ~1-5% of parameters are trainable
- Smaller gradient computations
- Less GPU memory required
### Easy Multi-Domain Inference
```python
# One base model, multiple domains
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
# Legal domain
model.load_adapter("./legal_adapter")
legal_results = model.extract_entities(legal_text, ["company", "law"])
# Medical domain (swap in <1 second)
model.load_adapter("./medical_adapter")
medical_results = model.extract_entities(medical_text, ["disease", "drug"])
```
## Training Your First Adapter
### Step 1: Prepare Domain-Specific Data
```python
from gliner2.training.data import InputExample
# Legal domain examples
legal_examples = [
InputExample(
text="Apple Inc. filed a lawsuit against Samsung Electronics.",
entities={"company": ["Apple Inc.", "Samsung Electronics"]}
),
InputExample(
text="The plaintiff Google LLC accused Microsoft Corporation of patent infringement.",
entities={"company": ["Google LLC", "Microsoft Corporation"]}
),
InputExample(
text="Tesla Motors settled the case with the Securities and Exchange Commission.",
entities={
"company": ["Tesla Motors"],
"organization": ["Securities and Exchange Commission"]
}
),
# Add 100-1000+ examples for best results
]
```
### Step 2: Configure LoRA Training
```python
from gliner2 import GLiNER2
from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig
# LoRA configuration
config = TrainingConfig(
output_dir="./legal_adapter",
experiment_name="legal_domain",
# Training parameters
num_epochs=10,
batch_size=8,
gradient_accumulation_steps=2,
encoder_lr=1e-5,
task_lr=5e-4,
# LoRA settings
use_lora=True, # Enable LoRA
lora_r=8, # Rank (4, 8, 16, 32)
lora_alpha=16.0, # Scaling factor (usually 2*r)
lora_dropout=0.0, # Dropout for LoRA layers
lora_target_modules=["encoder"], # Apply to all encoder layers (query, key, value, dense)
save_adapter_only=True, # Save only adapter (not full model)
# Optimization
eval_strategy="epoch", # Evaluates and saves at end of each epoch
eval_steps=500, # Used when eval_strategy="steps"
logging_steps=50,
fp16=True, # Use mixed precision if GPU available
)
```
### Step 3: Train the Adapter
```python
# Load base model
base_model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
# Create trainer
trainer = GLiNER2Trainer(model=base_model, config=config)
# Train adapter
trainer.train(train_data=legal_examples)
# Adapter automatically saved to ./legal_adapter/final/
```
**Training output:**
```
🔧 LoRA Configuration
======================================================================
Enabled : True
Rank (r) : 8
Alpha : 16.0
Scaling (α/r) : 2.0000
Dropout : 0.0
Target modules : query, key, value, dense
LoRA layers : 144
----------------------------------------------------------------------
Trainable params : 1,327,104 / 124,442,368 (1.07%)
Memory savings : ~98.9% fewer gradients
======================================================================
***** Running Training *****
Num examples = 1000
Num epochs = 10
Batch size = 8
Effective batch size = 16
Total optimization steps = 625
LoRA enabled: 1,327,104 trainable / 124,442,368 total (1.07%)
```
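The step counts in that log follow directly from the config; a quick sanity check using the numbers shown above:

```python
num_examples = 1000
num_epochs = 10
batch_size = 8
grad_accum = 2

effective_batch = batch_size * grad_accum                    # 16
batches_per_epoch = num_examples // batch_size               # 125
total_steps = num_epochs * batches_per_epoch // grad_accum   # 625
print(effective_batch, total_steps)
```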
## Training Multiple Domain Adapters
Let's train adapters for three different domains: **Legal**, **Medical**, and **Customer Support**.
### Complete Multi-Domain Training Script
```python
from gliner2 import GLiNER2
from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig
from gliner2.training.data import InputExample
# ============================================================================
# Define Domain Data
# ============================================================================
# Legal domain
legal_examples = [
InputExample(
text="Apple Inc. filed a lawsuit against Samsung Electronics.",
entities={"company": ["Apple Inc.", "Samsung Electronics"]}
),
InputExample(
text="The plaintiff Google LLC accused Microsoft Corporation of patent infringement.",
entities={"company": ["Google LLC", "Microsoft Corporation"]}
),
# Add more examples...
]
# Medical domain
medical_examples = [
InputExample(
text="Patient diagnosed with Type 2 Diabetes and Hypertension.",
entities={"disease": ["Type 2 Diabetes", "Hypertension"]}
),
InputExample(
text="Prescribed Metformin 500mg twice daily and Lisinopril 10mg once daily.",
entities={
"drug": ["Metformin", "Lisinopril"],
"dosage": ["500mg", "10mg"]
}
),
# Add more examples...
]
# Customer support domain
support_examples = [
InputExample(
text="Customer John Smith reported issue with Order #12345.",
entities={
"customer": ["John Smith"],
"order_id": ["Order #12345"]
}
),
InputExample(
text="Refund of $99.99 processed for Order #98765 on 2024-01-15.",
entities={
"order_id": ["Order #98765"],
"amount": ["$99.99"],
"date": ["2024-01-15"]
}
),
# Add more examples...
]
# ============================================================================
# Training Function
# ============================================================================
def train_domain_adapter(
base_model_name: str,
examples: list,
domain_name: str,
output_dir: str = "./adapters"
):
"""Train a LoRA adapter for a specific domain."""
adapter_path = f"{output_dir}/{domain_name}_adapter"
config = TrainingConfig(
output_dir=adapter_path,
experiment_name=f"{domain_name}_domain",
# Training
num_epochs=10,
batch_size=8,
gradient_accumulation_steps=2,
encoder_lr=1e-5,
task_lr=5e-4,
# LoRA
use_lora=True,
lora_r=8,
lora_alpha=16.0,
lora_dropout=0.0,
lora_target_modules=["encoder"], # All encoder layers
save_adapter_only=True,
# Logging & Checkpointing
eval_strategy="no", # Set to "epoch" or "steps" if you have validation set
eval_steps=500, # Used when eval_strategy="steps"
logging_steps=50,
fp16=True,
)
# Load base model
print(f"\n{'='*60}")
print(f"Training {domain_name.upper()} adapter")
print(f"{'='*60}")
model = GLiNER2.from_pretrained(base_model_name)
trainer = GLiNER2Trainer(model=model, config=config)
# Train
results = trainer.train(train_data=examples)
print(f"\n✅ {domain_name.capitalize()} adapter trained!")
print(f"📁 Saved to: {adapter_path}/final/")
print(f"⏱️ Training time: {results['total_time_seconds']:.2f}s")
return f"{adapter_path}/final"
# ============================================================================
# Train All Adapters
# ============================================================================
if __name__ == "__main__":
BASE_MODEL = "fastino/gliner2-base-v1"
# Train adapters for each domain
legal_adapter_path = train_domain_adapter(
BASE_MODEL, legal_examples, "legal"
)
medical_adapter_path = train_domain_adapter(
BASE_MODEL, medical_examples, "medical"
)
support_adapter_path = train_domain_adapter(
BASE_MODEL, support_examples, "support"
)
print("\n" + "="*60)
print("🎉 All adapters trained successfully!")
print("="*60)
print(f"Legal adapter: {legal_adapter_path}")
print(f"Medical adapter: {medical_adapter_path}")
print(f"Support adapter: {support_adapter_path}")
```
## Loading and Swapping Adapters
### Basic Usage
```python
from gliner2 import GLiNER2
# Load base model once
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
# Load legal adapter
model.load_adapter("./adapters/legal_adapter/final")
# Use the model
result = model.extract_entities(
"Apple Inc. sued Samsung over patent rights.",
["company", "legal_action"]
)
print(result)
```
### Swapping Between Adapters
```python
# Load base model
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
# Legal domain
print("📋 Legal Analysis:")
model.load_adapter("./adapters/legal_adapter/final")
legal_text = "Google LLC filed a complaint against Oracle Corporation."
legal_result = model.extract_entities(legal_text, ["company", "legal_action"])
print(f" {legal_result}")
# Swap to medical domain
print("\n🏥 Medical Analysis:")
model.load_adapter("./adapters/medical_adapter/final")
medical_text = "Patient presents with Pneumonia and was prescribed Amoxicillin."
medical_result = model.extract_entities(medical_text, ["disease", "drug"])
print(f" {medical_result}")
# Swap to support domain
print("\n💬 Support Analysis:")
model.load_adapter("./adapters/support_adapter/final")
support_text = "Customer reported Order #12345 not delivered on time."
support_result = model.extract_entities(support_text, ["order_id", "issue"])
print(f" {support_result}")
# Use base model without adapter
print("\n🔧 Base Model (no adapter):")
model.unload_adapter()
base_result = model.extract_entities("Some generic text", ["entity"])
print(f" {base_result}")
```
**Output:**
```
📋 Legal Analysis:
{'entities': [{'text': 'Google LLC', 'label': 'company', ...},
{'text': 'Oracle Corporation', 'label': 'company', ...}]}
🏥 Medical Analysis:
{'entities': [{'text': 'Pneumonia', 'label': 'disease', ...},
{'text': 'Amoxicillin', 'label': 'drug', ...}]}
💬 Support Analysis:
{'entities': [{'text': 'Order #12345', 'label': 'order_id', ...}]}
🔧 Base Model (no adapter):
{'entities': [{'text': 'text', 'label': 'entity', ...}]}
```
### Batch Processing with Adapter Swapping
```python
def process_documents_by_domain(model, documents_by_domain, adapters):
"""
Process multiple documents across different domains efficiently.
Args:
model: Base GLiNER2 model
documents_by_domain: Dict[domain_name, List[document_text]]
adapters: Dict[domain_name, adapter_path]
Returns:
Dict[domain_name, List[results]]
"""
results = {}
for domain, documents in documents_by_domain.items():
print(f"Processing {domain} domain ({len(documents)} documents)...")
# Load domain-specific adapter
model.load_adapter(adapters[domain])
# Process all documents for this domain
domain_results = []
for doc in documents:
result = model.extract_entities(doc, get_entity_types(domain))
domain_results.append(result)
results[domain] = domain_results
return results
def get_entity_types(domain):
"""Get entity types for each domain."""
types = {
"legal": ["company", "person", "law", "legal_action"],
"medical": ["disease", "drug", "symptom", "procedure"],
"support": ["customer", "order_id", "product", "issue"]
}
return types.get(domain, ["entity"])
# Example usage
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
documents_by_domain = {
"legal": [
"Apple Inc. filed suit against Samsung.",
"Microsoft acquired LinkedIn for $26B.",
],
"medical": [
"Patient has Type 2 Diabetes.",
"Prescribed Metformin 500mg daily.",
],
"support": [
"Issue with Order #12345 reported.",
"Refund processed for Order #98765.",
]
}
adapters = {
"legal": "./adapters/legal_adapter/final",
"medical": "./adapters/medical_adapter/final",
"support": "./adapters/support_adapter/final",
}
results = process_documents_by_domain(model, documents_by_domain, adapters)
# Results organized by domain
for domain, domain_results in results.items():
print(f"\n{domain.upper()} Results:")
for i, result in enumerate(domain_results, 1):
print(f" Document {i}: {len(result['entities'])} entities found")
```
## Real-World Use Cases
### Use Case 1: Multi-Tenant SaaS Platform
```python
class MultiTenantEntityExtractor:
"""Entity extraction service for multi-tenant platform."""
def __init__(self, base_model_name: str, tenant_adapters: dict):
"""
Args:
base_model_name: Path to base model
tenant_adapters: Dict mapping tenant_id to adapter_path
"""
self.model = GLiNER2.from_pretrained(base_model_name)
self.tenant_adapters = tenant_adapters
self.current_tenant = None
def extract_for_tenant(self, tenant_id: str, text: str, entity_types: list):
"""Extract entities for specific tenant."""
# Load tenant-specific adapter if needed
if self.current_tenant != tenant_id:
adapter_path = self.tenant_adapters.get(tenant_id)
if adapter_path:
self.model.load_adapter(adapter_path)
else:
self.model.unload_adapter() # Use base model
self.current_tenant = tenant_id
return self.model.extract_entities(text, entity_types)
# Setup
extractor = MultiTenantEntityExtractor(
base_model_name="fastino/gliner2-base-v1",
tenant_adapters={
"legal_firm_123": "./adapters/legal_adapter/final",
"hospital_456": "./adapters/medical_adapter/final",
"ecommerce_789": "./adapters/support_adapter/final",
}
)
# Usage
legal_result = extractor.extract_for_tenant(
"legal_firm_123",
"Apple sued Samsung",
["company"]
)
medical_result = extractor.extract_for_tenant(
"hospital_456",
"Patient has diabetes",
["disease"]
)
```
### Use Case 2: Document Classification Pipeline
```python
def classify_and_extract(document: str, model: GLiNER2, adapters: dict):
"""
Classify document type and extract relevant entities.
1. Classify document type using base model
2. Load appropriate domain adapter
3. Extract domain-specific entities
"""
# Step 1: Classify document type
doc_type_result = model.extract_entities(
document,
["legal_document", "medical_record", "support_ticket", "financial_report"]
)
# Determine document type
if doc_type_result['entities']:
doc_type = doc_type_result['entities'][0]['label']
        doc_type = doc_type.rsplit("_", 1)[0]  # "legal_document" -> "legal", "medical_record" -> "medical", etc.
else:
doc_type = "general"
# Step 2: Load appropriate adapter
adapter_mapping = {
"legal": adapters.get("legal"),
"medical": adapters.get("medical"),
"support": adapters.get("support"),
"financial": adapters.get("financial"),
}
if doc_type in adapter_mapping and adapter_mapping[doc_type]:
model.load_adapter(adapter_mapping[doc_type])
# Step 3: Extract domain-specific entities
entity_types = {
"legal": ["company", "person", "law", "legal_action"],
"medical": ["disease", "drug", "symptom", "procedure", "dosage"],
"support": ["customer", "order_id", "product", "issue", "status"],
"financial": ["company", "amount", "date", "stock_symbol"],
}
entities = model.extract_entities(
document,
entity_types.get(doc_type, ["entity"])
)
return {
"document_type": doc_type,
"entities": entities['entities']
}
# Usage
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
adapters = {
"legal": "./adapters/legal_adapter/final",
"medical": "./adapters/medical_adapter/final",
"support": "./adapters/support_adapter/final",
}
document = "Patient John Smith diagnosed with Type 2 Diabetes on 2024-01-15."
result = classify_and_extract(document, model, adapters)
print(f"Document Type: {result['document_type']}")
print(f"Entities: {result['entities']}")
```
### Use Case 3: A/B Testing Adapters
```python
import random
class AdapterABTester:
"""A/B test different adapter versions."""
def __init__(self, base_model_name: str, adapter_variants: dict):
"""
Args:
adapter_variants: {"v1": path1, "v2": path2, ...}
"""
self.model = GLiNER2.from_pretrained(base_model_name)
self.adapter_variants = adapter_variants
self.results = {variant: [] for variant in adapter_variants}
def test_sample(self, text: str, entity_types: list, true_entities: list):
"""Test a sample with all adapter variants."""
sample_results = {}
for variant, adapter_path in self.adapter_variants.items():
# Load variant
self.model.load_adapter(adapter_path)
# Get predictions
pred = self.model.extract_entities(text, entity_types)
# Compute metrics
f1 = self.compute_f1(pred['entities'], true_entities)
sample_results[variant] = {
"predictions": pred['entities'],
"f1_score": f1
}
self.results[variant].append(f1)
return sample_results
def compute_f1(self, predicted, ground_truth):
"""Simple F1 computation (simplified for demo)."""
pred_set = {(e['text'], e['label']) for e in predicted}
true_set = {(e['text'], e['label']) for e in ground_truth}
if not pred_set and not true_set:
return 1.0
if not pred_set or not true_set:
return 0.0
tp = len(pred_set & true_set)
precision = tp / len(pred_set) if pred_set else 0
recall = tp / len(true_set) if true_set else 0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def get_summary(self):
"""Get A/B test summary."""
summary = {}
for variant, scores in self.results.items():
if scores:
summary[variant] = {
"avg_f1": sum(scores) / len(scores),
"samples": len(scores)
}
return summary
# Usage
tester = AdapterABTester(
base_model_name="fastino/gliner2-base-v1",
adapter_variants={
"v1_r4": "./adapters/legal_v1_r4/final",
"v2_r8": "./adapters/legal_v2_r8/final",
"v3_r16": "./adapters/legal_v3_r16/final",
}
)
# Test samples
test_samples = [
{
"text": "Apple Inc. sued Samsung Electronics.",
"entity_types": ["company"],
"true_entities": [
{"text": "Apple Inc.", "label": "company"},
{"text": "Samsung Electronics", "label": "company"}
]
},
# More samples...
]
for sample in test_samples:
results = tester.test_sample(
sample["text"],
sample["entity_types"],
sample["true_entities"]
)
# Get summary
summary = tester.get_summary()
for variant, metrics in summary.items():
print(f"{variant}: Avg F1 = {metrics['avg_f1']:.3f} ({metrics['samples']} samples)")
```
## Best Practices
### 1. Choosing LoRA Hyperparameters
```python
# Small datasets (< 1K examples)
config = TrainingConfig(
lora_r=4, # Lower rank = fewer parameters
lora_alpha=8.0, # alpha = 2 * r
num_epochs=10,
)
# Medium datasets (1K-10K examples)
config = TrainingConfig(
lora_r=8, # Standard rank
lora_alpha=16.0,
num_epochs=5,
)
# Large datasets (> 10K examples)
config = TrainingConfig(
lora_r=16, # Higher rank = more capacity
lora_alpha=32.0,
num_epochs=3,
)
```
### 2. Target Module Selection
**Understanding Module Groups:**
GLiNER2 supports fine-grained control over which layers receive LoRA adaptation:
```python
# Option 1: Encoder only - all layers (query, key, value, dense)
# Use case: General domain adaptation, good starting point
# Memory: Moderate (~1-2% of model parameters)
lora_target_modules=["encoder"]
# Option 2: Encoder - attention layers only
# Use case: Very memory-constrained scenarios
# Memory: Low (~0.5-1% of model parameters)
lora_target_modules=["encoder.query", "encoder.key", "encoder.value"]
# Option 3: Encoder - FFN layers only
# Use case: Alternative to attention-only, sometimes better for certain tasks
# Memory: Low (~0.5-1% of model parameters)
lora_target_modules=["encoder.dense"]
# Option 4: Encoder + task heads
# Use case: When you want to adapt both representation and task-specific layers
# Memory: Moderate-High (~2-4% of model parameters)
lora_target_modules=["encoder", "span_rep", "classifier"]
# Option 5: All modules (DEFAULT)
# Use case: Maximum adaptation capacity, best performance
# Memory: High (~3-5% of model parameters)
lora_target_modules=["encoder", "span_rep", "classifier", "count_embed", "count_pred"]
```
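The memory percentages above follow from how LoRA sizes its update: for a weight matrix of shape `(d_out, d_in)`, a rank-`r` adapter adds `r * (d_in + d_out)` parameters (the two low-rank matrices A and B). A rough back-of-the-envelope, assuming a 768-dimensional encoder (the dimension here is an assumption for illustration):

```python
def lora_params(d_in, d_out, r):
    # A has shape (r, d_in), B has shape (d_out, r)
    return r * (d_in + d_out)

r = 8
attn_matrix = lora_params(768, 768, r)  # one of query/key/value: 12,288 params
print(attn_matrix)
```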
**Recommendations:**
- **Start with encoder only** (`["encoder"]`) for most tasks
- **Add task heads** if performance is insufficient
- **Use attention-only** for extreme memory constraints
- **Use all modules** (default) when you need maximum performance
### 3. Adapter Organization
```
project/
├── base_model/
│ └── gliner2-base-v1/
├── adapters/
│ ├── legal/
│ │ ├── v1_r8/
│ │ │ └── final/
│ │ └── v2_r16/
│ │ └── final/
│ ├── medical/
│ │ └── final/
│ └── support/
│ └── final/
└── scripts/
├── train_adapters.py
└── evaluate_adapters.py
```
### 4. Version Control for Adapters
`adapter_metadata.json`:
```json
{
"legal_v1": {
"path": "./adapters/legal/v1_r8/final",
"base_model": "fastino/gliner2-base-v1",
"lora_r": 8,
"lora_alpha": 16.0,
"trained_on": "2024-01-15",
"training_samples": 5000,
"eval_f1": 0.87,
"notes": "Initial legal domain adapter"
},
"legal_v2": {
"path": "./adapters/legal/v2_r16/final",
"base_model": "fastino/gliner2-base-v1",
"lora_r": 16,
"lora_alpha": 32.0,
"trained_on": "2024-02-01",
"training_samples": 10000,
"eval_f1": 0.92,
"notes": "Improved with more data and higher rank"
}
}
```
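A small helper can then pick the best adapter from such a metadata file (the snippet below embeds a trimmed copy of the example; paths and scores are illustrative):

```python
import json

metadata = json.loads("""
{
  "legal_v1": {"path": "./adapters/legal/v1_r8/final", "eval_f1": 0.87},
  "legal_v2": {"path": "./adapters/legal/v2_r16/final", "eval_f1": 0.92}
}
""")

def best_adapter(meta):
    """Return (name, path) of the adapter with the highest eval_f1."""
    name = max(meta, key=lambda k: meta[k]["eval_f1"])
    return name, meta[name]["path"]

name, path = best_adapter(metadata)
print(name, path)  # legal_v2 ./adapters/legal/v2_r16/final
```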
### 5. Monitoring Adapter Performance
```python
def evaluate_adapter(model, adapter_path, test_data):
"""Evaluate adapter performance on test data."""
model.load_adapter(adapter_path)
results = {
"total": 0,
"correct": 0,
"precision_sum": 0,
"recall_sum": 0,
}
for sample in test_data:
pred = model.extract_entities(sample["text"], sample["entity_types"])
# Compute metrics
metrics = compute_metrics(pred['entities'], sample["true_entities"])
results["total"] += 1
results["precision_sum"] += metrics["precision"]
results["recall_sum"] += metrics["recall"]
    avg_precision = results["precision_sum"] / results["total"]
    avg_recall = results["recall_sum"] / results["total"]
    denom = avg_precision + avg_recall
    f1 = 2 * avg_precision * avg_recall / denom if denom > 0 else 0.0
return {
"precision": avg_precision,
"recall": avg_recall,
"f1": f1,
"samples": results["total"]
}
```
## Troubleshooting
### Issue 1: Adapter Not Affecting Predictions
**Symptom**: Predictions are the same with and without adapter.
**Solution**:
```python
# Check if adapter is actually loaded
print(f"Has adapter: {model.has_adapter}")
# Check LoRA layers
from gliner2.training.lora import LoRALayer
lora_count = sum(1 for m in model.modules() if isinstance(m, LoRALayer))
print(f"LoRA layers: {lora_count}")
# Should be > 0 if adapter is loaded
assert lora_count > 0, "No LoRA layers found!"
```
### Issue 2: Out of Memory During Training
**Solution**:
```python
config = TrainingConfig(
# Reduce batch size
batch_size=4, # Instead of 8
gradient_accumulation_steps=4, # Maintain effective batch size
# Use smaller LoRA rank
lora_r=4, # Instead of 8
# Enable mixed precision
fp16=True,
# Target only attention layers (fewer parameters)
lora_target_modules=["encoder.query", "encoder.key", "encoder.value"],
)
```
### Issue 3: Adapter File Not Found
**Solution**:
```python
import os
from gliner2.training.lora import LoRAAdapterConfig
adapter_path = "./adapters/legal_adapter/final"
# Check if path exists
if not os.path.exists(adapter_path):
print(f"Path does not exist: {adapter_path}")
# List available checkpoints
checkpoint_dir = "./adapters/legal_adapter"
if os.path.exists(checkpoint_dir):
checkpoints = os.listdir(checkpoint_dir)
print(f"Available checkpoints: {checkpoints}")
# Check if it's a valid adapter
if LoRAAdapterConfig.is_adapter_path(adapter_path):
print("Valid adapter path!")
config = LoRAAdapterConfig.load(adapter_path)
print(f"Adapter config: {config}")
else:
print("Not a valid adapter path!")
```
### Issue 4: Slow Adapter Switching
**Problem**: Switching between adapters takes too long.
**Solution**:
```python
# Pre-load adapters in memory (if you have enough RAM)
adapters = {}
for domain, path in adapter_paths.items():
# Load adapter weights into memory
adapters[domain] = load_adapter_to_memory(path)
# Fast switching from memory (not implemented in base API,
# but possible with custom caching layer)
```
## Summary
### Key Takeaways
- **LoRA adapters** enable efficient multi-domain inference
- **Training** is 2-3x faster than full fine-tuning
- **Storage** savings of 65-95% compared to multiple full models
- **Swapping** adapters takes < 1 second
- **Domain specialization** improves accuracy on specific tasks
### Quick Reference
```python
# Training
config = TrainingConfig(
use_lora=True,
lora_r=8,
lora_alpha=16.0,
save_adapter_only=True,
)
trainer.train(train_data=examples)
# Loading
model = GLiNER2.from_pretrained("base-model")
model.load_adapter("./adapter/final")
# Swapping
model.load_adapter("./other_adapter/final")
# Unloading
model.unload_adapter()
# Checking
print(model.has_adapter)
print(model.adapter_config)
```
### Next Steps
1. **Train your first adapter** with domain-specific data
2. **Evaluate performance** on test set
3. **Experiment with hyperparameters** (rank, alpha, target modules)
4. **Deploy multiple adapters** for different use cases
5. **Monitor and iterate** based on real-world performance
For more information:
- LoRA Paper: https://arxiv.org/abs/2106.09685
- Implementation: `gliner2/training/lora.py`
- Tests: `tests/test_lora_adapters.py`
- Verification Guide: `LORA_VERIFICATION_TESTS.md`
# Tutorial 11: LoRA Adapter Switching/Routing
## Quick Start
Switch between domain-specific adapters during inference without reloading the base model.
```python
from gliner2 import GLiNER2
# Load base model once
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
# Load legal adapter
model.load_adapter("./legal_adapter")
legal_result = model.extract_entities("Apple sued Google", ["company"])
# Switch to medical adapter
model.load_adapter("./medical_adapter")
medical_result = model.extract_entities("Patient has diabetes", ["disease"])
# Use base model (no adapter)
model.unload_adapter()
base_result = model.extract_entities("Some text", ["entity"])
```
## Basic Usage
### Loading an Adapter
```python
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
model.load_adapter("./path/to/adapter")
```
The adapter path should point to a directory containing:
- `adapter_config.json`
- `adapter_weights.safetensors`
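To sanity-check an adapter directory before loading it, a small helper can read the config file. This is an illustrative sketch, not part of the GLiNER2 API; the JSON field names are assumed to mirror the training config shown earlier (e.g. `lora_r`).

```python
import json
from pathlib import Path

def read_adapter_config(adapter_dir):
    """Load and return adapter_config.json from an adapter directory (sketch)."""
    config_path = Path(adapter_dir) / "adapter_config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"No adapter_config.json in {adapter_dir}")
    return json.loads(config_path.read_text())
```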
### Checking Adapter Status
```python
# Check if adapter is loaded
if model.has_adapter:
print("Adapter is loaded")
# Get adapter configuration
config = model.adapter_config
print(f"LoRA rank: {config.lora_r}")
```
### Unloading an Adapter
```python
# Remove adapter, use base model
model.unload_adapter()
```
## Switching Between Adapters
Adapters automatically swap when you call `load_adapter()`:
```python
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
# Legal domain
model.load_adapter("./legal_adapter")
result1 = model.extract_entities("Apple Inc. filed suit", ["company"])
# Medical domain (previous adapter auto-unloaded)
model.load_adapter("./medical_adapter")
result2 = model.extract_entities("Patient has diabetes", ["disease"])
# Support domain
model.load_adapter("./support_adapter")
result3 = model.extract_entities("Order #12345 issue", ["order_id"])
```
## Routing by Document Type
Route documents to appropriate adapters:
```python
def extract_with_routing(model, text, doc_type, adapters):
"""Route document to domain-specific adapter."""
adapter_path = adapters.get(doc_type)
if adapter_path:
model.load_adapter(adapter_path)
else:
model.unload_adapter() # Use base model
# Define entity types per domain
entity_types = {
"legal": ["company", "person", "law"],
"medical": ["disease", "drug", "symptom"],
"support": ["order_id", "customer", "issue"]
}
return model.extract_entities(
text,
entity_types.get(doc_type, ["entity"])
)
# Setup
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
adapters = {
"legal": "./legal_adapter",
"medical": "./medical_adapter",
"support": "./support_adapter"
}
# Use
result = extract_with_routing(
model,
"Apple sued Google",
"legal",
adapters
)
```
## Batch Processing by Domain
Process multiple documents efficiently:
```python
def process_by_domain(model, documents, adapters):
"""Process documents grouped by domain."""
results = {}
for domain, docs in documents.items():
# Load domain adapter
model.load_adapter(adapters[domain])
# Process all documents for this domain
results[domain] = [
model.extract_entities(doc, get_entity_types(domain))
for doc in docs
]
    return results

# Per-domain entity types used above (illustrative helper)
def get_entity_types(domain):
    return {
        "legal": ["company", "person", "law"],
        "medical": ["disease", "drug", "symptom"]
    }.get(domain, ["entity"])

# Example
documents = {
"legal": ["Apple sued Samsung", "Microsoft acquired LinkedIn"],
"medical": ["Patient has diabetes", "Prescribed Metformin"]
}
adapters = {
"legal": "./legal_adapter",
"medical": "./medical_adapter"
}
results = process_by_domain(model, documents, adapters)
```
## Simple Router Class
```python
class AdapterRouter:
"""Simple adapter router for multi-domain inference."""
def __init__(self, base_model_name, adapters):
self.model = GLiNER2.from_pretrained(base_model_name)
self.adapters = adapters
self.current_domain = None
def extract(self, text, domain, entity_types):
"""Extract entities using domain-specific adapter."""
# Load adapter if domain changed
if self.current_domain != domain:
adapter_path = self.adapters.get(domain)
if adapter_path:
self.model.load_adapter(adapter_path)
else:
self.model.unload_adapter()
self.current_domain = domain
return self.model.extract_entities(text, entity_types)
# Usage
router = AdapterRouter(
"fastino/gliner2-base-v1",
{
"legal": "./legal_adapter",
"medical": "./medical_adapter"
}
)
result = router.extract("Apple sued Google", "legal", ["company"])
```
## Summary
- **Load adapter**: `model.load_adapter(path)`
- **Unload adapter**: `model.unload_adapter()`
- **Check status**: `model.has_adapter`
- **Get config**: `model.adapter_config`
- **Auto-swap**: Loading a new adapter automatically unloads the previous one
For training adapters, see [Tutorial 10: LoRA Adapters](10-lora_adapters.md).

# GLiNER2 Entity Extraction Tutorial
Learn how to extract named entities from text using GLiNER2's flexible entity recognition capabilities.
## Table of Contents
- [Basic Entity Extraction](#basic-entity-extraction)
- [Entity Extraction with Descriptions](#entity-extraction-with-descriptions)
- [Single vs Multiple Entities](#single-vs-multiple-entities)
- [Custom Thresholds](#custom-thresholds)
- [Advanced Configuration](#advanced-configuration)
- [Domain-Specific Entities](#domain-specific-entities)
- [Best Practices](#best-practices)
## Basic Entity Extraction
### Simple Example
```python
from gliner2 import GLiNER2
# Load model
extractor = GLiNER2.from_pretrained("your-model-name")
# Extract common entities
text = "Apple Inc. CEO Tim Cook announced the new iPhone 15 in Cupertino, California on September 12, 2023."
results = extractor.extract_entities(
text,
["company", "person", "product", "location", "date"]
)
print(results)
# Output: {
# 'entities': {
# 'company': ['Apple Inc.'],
# 'person': ['Tim Cook'],
# 'product': ['iPhone 15'],
# 'location': ['Cupertino', 'California'],
# 'date': ['September 12, 2023']
# }
# }
```
### Using Schema Builder
```python
# Same extraction using schema
schema = extractor.create_schema().entities([
"company", "person", "product", "location", "date"
])
results = extractor.extract(text, schema)
```
## Entity Extraction with Descriptions
Descriptions significantly improve extraction accuracy by providing context.
```python
# Medical entity extraction
schema = extractor.create_schema().entities({
"drug": "Pharmaceutical drugs, medications, or treatment names",
"disease": "Medical conditions, illnesses, or disorders",
"symptom": "Clinical symptoms or patient-reported symptoms",
"dosage": "Medication amounts like '50mg' or '2 tablets daily'",
"organ": "Body parts or organs mentioned in medical context"
})
medical_text = """
Patient was prescribed Metformin 500mg twice daily for Type 2 Diabetes.
She reported fatigue and occasional dizziness. Liver function tests ordered.
"""
results = extractor.extract(medical_text, schema)
print(results)
# Output: {
# 'entities': {
# 'drug': ['Metformin'],
# 'disease': ['Type 2 Diabetes'],
# 'symptom': ['fatigue', 'dizziness'],
# 'dosage': ['500mg twice daily'],
# 'organ': ['Liver']
# }
# }
```
## Single vs Multiple Entities
Control whether to extract one or multiple entities per type.
### Multiple Entities (Default)
```python
# Default behavior - extracts all matching entities
schema = extractor.create_schema().entities(
["person", "organization"],
dtype="list" # Default
)
text = "Bill Gates and Steve Jobs founded Microsoft and Apple respectively."
results = extractor.extract(text, schema)
# Output: {
# 'entities': {
# 'person': ['Bill Gates', 'Steve Jobs'],
# 'organization': ['Microsoft', 'Apple']
# }
# }
```
### Single Entity per Type
```python
# Extract only the best match per entity type
schema = extractor.create_schema().entities(
["company", "ceo"],
dtype="str" # Single entity mode
)
text = "Apple CEO Tim Cook met with Microsoft CEO Satya Nadella."
results = extractor.extract(text, schema)
# Output: {
# 'entities': {
# 'company': 'Apple', # Just one, despite multiple in text
# 'ceo': 'Tim Cook' # Just one
# }
# }
```
## Custom Thresholds
Set confidence thresholds for precise control.
### Global Threshold
```python
# High-precision extraction
results = extractor.extract_entities(
text,
["email", "phone", "address"],
threshold=0.8 # High confidence required
)
```
### With Confidence Scores and Character Positions
You can include confidence scores and character-level start/end positions using `include_confidence` and `include_spans` parameters:
```python
# Extract entities with confidence scores
text = "Apple Inc. CEO Tim Cook announced iPhone 15 in Cupertino."
results = extractor.extract_entities(
text,
["company", "person", "product"],
include_confidence=True
)
print(results)
# Output: {
# 'entities': {
# 'company': [
# {'text': 'Apple Inc.', 'confidence': 0.95}
# ],
# 'person': [
# {'text': 'Tim Cook', 'confidence': 0.92}
# ],
# 'product': [
# {'text': 'iPhone 15', 'confidence': 0.88}
# ]
# }
# }
# Extract with character positions (spans)
results = extractor.extract_entities(
text,
["company", "person"],
include_spans=True
)
print(results)
# Output: {
# 'entities': {
# 'company': [
# {'text': 'Apple Inc.', 'start': 0, 'end': 10}
# ],
# 'person': [
# {'text': 'Tim Cook', 'start': 15, 'end': 23}
# ]
# }
# }
# Extract with both confidence and spans
results = extractor.extract_entities(
text,
["company", "product"],
include_confidence=True,
include_spans=True
)
print(results)
# Output: {
# 'entities': {
# 'company': [
# {'text': 'Apple Inc.', 'confidence': 0.95, 'start': 0, 'end': 10}
# ],
# 'product': [
# {'text': 'iPhone 15', 'confidence': 0.88, 'start': 34, 'end': 43}
# ]
# }
# }
```
**Note**: When `include_confidence` or `include_spans` is True, the output format changes:
- **Default** (both False): Returns simple text strings: `['Apple Inc.', 'Tim Cook']`
- **include_confidence=True**: Returns dicts with `{'text': '...', 'confidence': 0.95}`
- **include_spans=True**: Returns dicts with `{'text': '...', 'start': 0, 'end': 9}`
- **Both True**: Returns dicts with `{'text': '...', 'confidence': 0.95, 'start': 0, 'end': 9}`
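Downstream code often just wants the plain strings back regardless of which flags were set. A small normalizer over the shapes listed above (a sketch, not part of the GLiNER2 API):

```python
def entity_texts(entities):
    """Flatten an 'entities' result back to plain strings, whatever the flags."""
    flat = {}
    for label, values in entities.items():
        if isinstance(values, dict):      # dtype="str" with confidence/spans
            flat[label] = values["text"]
        elif isinstance(values, str):     # dtype="str", plain text
            flat[label] = values
        else:                             # list of strings or of dicts
            flat[label] = [v["text"] if isinstance(v, dict) else v for v in values]
    return flat
```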
### Per-Entity Thresholds
```python
# Different thresholds for different entities
schema = extractor.create_schema().entities({
"email": {
"description": "Email addresses",
"dtype": "list",
"threshold": 0.9 # Very high precision for emails
},
"phone": {
"description": "Phone numbers including mobile and landline",
"dtype": "list",
"threshold": 0.7 # Moderate threshold
},
"name": {
"description": "Person names",
"dtype": "list",
"threshold": 0.5 # Lower threshold for names
}
})
contact_text = "Contact John Doe at john.doe@email.com or call 555-1234."
results = extractor.extract(contact_text, schema, threshold=0.6) # Default threshold
```
## Advanced Configuration
### Mixed Configuration
```python
# Combine different entity configurations
schema = extractor.create_schema()
# Add simple entities
schema.entities(["date", "time", "currency"])
# Add entities with descriptions
schema.entities({
"technical_term": "Technical jargon or specialized terminology",
"metric": "Measurements, KPIs, or quantitative values"
})
# Add entities with full configuration
schema.entities({
"competitor": {
"description": "Competing companies or products",
"dtype": "list",
"threshold": 0.7
},
"revenue": {
"description": "Revenue figures or financial amounts",
"dtype": "str", # Only extract one
"threshold": 0.8
}
})
```
### Incremental Entity Addition
```python
# Build schema incrementally
schema = extractor.create_schema()
# Add entities in stages
schema.entities(["person", "location"]) # Basic entities
schema.entities({"company": "Company or organization names"}) # With description
schema.entities({ # With full config
"financial_term": {
"description": "Financial instruments, metrics, or terminology",
"threshold": 0.75
}
})
```
## Domain-Specific Entities
### Legal Entities
```python
legal_schema = extractor.create_schema().entities({
"party": "Parties involved in legal proceedings (plaintiff, defendant, etc.)",
"law_firm": "Law firm or legal practice names",
"court": "Court names or judicial bodies",
"statute": "Legal statutes, laws, or regulations cited",
"case": "Legal case names or citations",
"judge": "Names of judges or magistrates",
"legal_term": "Legal terminology or concepts"
})
legal_text = """
In the case of Smith v. Jones, Judge Sarah Williams of the Superior Court
ruled that the defendant violated Section 15.2 of the Consumer Protection Act.
The plaintiff was represented by Miller & Associates.
"""
results = extractor.extract(legal_text, legal_schema)
```
### Financial Entities
```python
finance_schema = extractor.create_schema().entities({
"ticker": "Stock ticker symbols (e.g., AAPL, GOOGL)",
"financial_metric": "Financial metrics like P/E ratio, market cap",
"currency_amount": "Monetary values with currency symbols",
"percentage": "Percentage values (e.g., 5.2%, -3%)",
"financial_org": "Banks, investment firms, financial institutions",
"market_index": "Stock market indices (S&P 500, NASDAQ, etc.)"
})
finance_text = """
AAPL rose 3.5% to $185.50 after beating earnings expectations.
The company's P/E ratio of 28.5 attracted Goldman Sachs analysts.
The NASDAQ composite gained 1.2% for the day.
"""
results = extractor.extract(finance_text, finance_schema)
```
### Scientific Entities
```python
science_schema = extractor.create_schema().entities({
"chemical": "Chemical compounds or elements",
"organism": "Biological organisms, species names",
"gene": "Gene names or identifiers",
"measurement": "Scientific measurements with units",
"research_method": "Research techniques or methodologies",
"institution": "Universities or research institutions"
})
science_text = """
Researchers at MIT discovered that the BRCA1 gene mutation increases
cancer risk by 70%. The study used CRISPR-Cas9 to modify DNA sequences
in Mus musculus specimens, measuring tumor growth in millimeters.
"""
results = extractor.extract(science_text, science_schema)
```
## Best Practices
### 1. Use Descriptive Entity Names
```python
# Good - Clear, specific entity types
schema.entities(["drug_name", "medical_device", "procedure_name"])
# Less ideal - Too generic
schema.entities(["thing", "item", "stuff"])
```
### 2. Provide Context with Descriptions
```python
# Good - Clear descriptions
schema.entities({
"acquisition_company": "Company that is acquiring another company",
"target_company": "Company being acquired",
"acquisition_price": "Purchase price or valuation of acquisition"
})
# Less ideal - No context
schema.entities(["company1", "company2", "price"])
```

# GLiNER2 JSON Structure Extraction Tutorial
Learn how to extract complex structured data from text using GLiNER2's hierarchical extraction capabilities.
## Table of Contents
- [Quick API with extract_json](#quick-api-with-extract_json)
- [Field Types and Specifications](#field-types-and-specifications)
- [Multiple Instances](#multiple-instances)
- [Schema Builder (Multi-Task)](#schema-builder-multi-task)
- [Real-World Examples](#real-world-examples)
- [Confidence Scores and Character Positions](#confidence-scores-and-character-positions)
- [Best Practices](#best-practices)
## Quick API with extract_json
For structure-only extraction, use the `extract_json()` method with the simple dictionary format:
### Basic Structure Extraction
```python
from gliner2 import GLiNER2
# Load model
extractor = GLiNER2.from_pretrained("your-model-name")
# Simple product extraction
text = "The MacBook Pro costs $1999 and features M3 chip, 16GB RAM, and 512GB storage."
results = extractor.extract_json(
text,
{
"product": [
"name::str",
"price",
"features"
]
}
)
print(results)
# Output: {
# 'product': [{
# 'name': 'MacBook Pro',
# 'price': ['$1999'],
# 'features': ['M3 chip', '16GB RAM', '512GB storage']
# }]
# }
```
### Contact Information
```python
text = """
Contact: John Smith
Email: john@example.com
Phones: 555-1234, 555-5678
Address: 123 Main St, NYC
"""
results = extractor.extract_json(
text,
{
"contact": [
"name::str",
"email::str",
"phone::list",
"address"
]
}
)
# Output: {
# 'contact': [{
# 'name': 'John Smith',
# 'email': 'john@example.com',
# 'phone': ['555-1234', '555-5678'],
# 'address': ['123 Main St, NYC']
# }]
# }
```
## Field Types and Specifications
### Field Specification Format
Fields support flexible specifications using `::` separators:
```
"field_name::type::description"
"field_name::[choice1|choice2|choice3]::type::description"
"field_name::description" # defaults to list type
"field_name" # simple field, defaults to list
```
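To make the grammar above concrete, here is a minimal stand-alone parser for these specs. It is illustrative only, not the GLiNER2 implementation, and it assumes `str` and `list` are the only type tokens.

```python
def parse_field_spec(spec):
    """Parse a 'name[::choices][::type][::description]' field spec (sketch)."""
    parts = spec.split("::")
    field = {"name": parts[0], "dtype": "list", "choices": None, "description": None}
    rest = parts[1:]
    # Optional choices segment: [choice1|choice2|...]
    if rest and rest[0].startswith("[") and rest[0].endswith("]"):
        field["choices"] = rest[0][1:-1].split("|")
        rest = rest[1:]
    # Optional type segment; anything else is treated as the description
    if rest and rest[0] in ("str", "list"):
        field["dtype"] = rest[0]
        rest = rest[1:]
    if rest:
        field["description"] = "::".join(rest)
    return field
```

For example, `parse_field_spec("name::str::Event name")` yields `dtype "str"` with description `"Event name"`, while a bare `"features"` defaults to a list field.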
### String vs List Fields
```python
text = """
Tech Conference 2024 on June 15th in San Francisco.
Topics include AI, Machine Learning, and Cloud Computing.
Registration fee: $299 for early bird tickets.
"""
results = extractor.extract_json(
text,
{
"event": [
"name::str::Event or conference name",
"date::str::Event date",
"location::str",
"topics::list::Conference topics",
"registration_fee::str"
]
}
)
# Output: {
# 'event': [{
# 'name': 'Tech Conference 2024',
# 'date': 'June 15th',
# 'location': 'San Francisco',
# 'topics': ['AI', 'Machine Learning', 'Cloud Computing'],
# 'registration_fee': '$299'
# }]
# }
```
### Choice Fields (Classification within Structure)
```python
text = """
Reservation at Le Bernardin for 4 people on March 15th at 7:30 PM.
We'd prefer outdoor seating. Two guests are vegetarian and one is gluten-free.
"""
results = extractor.extract_json(
text,
{
"reservation": [
"restaurant::str::Restaurant name",
"date::str",
"time::str",
"party_size::[1|2|3|4|5|6+]::str::Number of guests",
"seating::[indoor|outdoor|bar]::str::Seating preference",
"dietary::[vegetarian|vegan|gluten-free|none]::list::Dietary restrictions"
]
}
)
# Output: {
# 'reservation': [{
# 'restaurant': 'Le Bernardin',
# 'date': 'March 15th',
# 'time': '7:30 PM',
# 'party_size': '4',
# 'seating': 'outdoor',
# 'dietary': ['vegetarian', 'gluten-free']
# }]
# }
```
## Multiple Instances
GLiNER2 automatically extracts ALL instances of a structure found in text:
### Multiple Transactions
```python
text = """
Recent transactions:
- Jan 5: Starbucks $5.50 (food)
- Jan 5: Uber $23.00 (transport)
- Jan 6: Amazon $156.99 (shopping)
"""
results = extractor.extract_json(
text,
{
"transaction": [
"date::str",
"merchant::str",
"amount::str",
"category::[food|transport|shopping|utilities]::str"
]
}
)
# Output: {
# 'transaction': [
# {'date': 'Jan 5', 'merchant': 'Starbucks', 'amount': '$5.50', 'category': 'food'},
# {'date': 'Jan 5', 'merchant': 'Uber', 'amount': '$23.00', 'category': 'transport'},
# {'date': 'Jan 6', 'merchant': 'Amazon', 'amount': '$156.99', 'category': 'shopping'}
# ]
# }
```
### Multiple Hotel Bookings
```python
text = """
Alice Brown booked the Hilton Downtown from March 10 to March 12. She selected a double room
for $340 total with breakfast and parking included.
Robert Taylor reserved The Grand Hotel, April 1 to April 5, suite at $1,200 total.
Amenities include breakfast, wifi, gym, and spa access.
"""
results = extractor.extract_json(
text,
{
"booking": [
"guest::str::Guest name",
"hotel::str::Hotel name",
"check_in::str",
"check_out::str",
"room_type::[single|double|suite|deluxe]::str",
"total_price::str",
"amenities::[breakfast|wifi|parking|gym|spa]::list"
]
}
)
# Output: {
# 'booking': [
# {
# 'guest': 'Alice Brown',
# 'hotel': 'Hilton Downtown',
# 'check_in': 'March 10',
# 'check_out': 'March 12',
# 'room_type': 'double',
# 'total_price': '$340',
# 'amenities': ['breakfast', 'parking']
# },
# {
# 'guest': 'Robert Taylor',
# 'hotel': 'The Grand Hotel',
# 'check_in': 'April 1',
# 'check_out': 'April 5',
# 'room_type': 'suite',
# 'total_price': '$1,200',
# 'amenities': ['breakfast', 'wifi', 'gym', 'spa']
# }
# ]
# }
```
## Schema Builder (Multi-Task)
Use `create_schema()` only when combining structured extraction with other tasks (entities, classification):
### Multi-Task Extraction
```python
# Use schema builder for multi-task scenarios
schema = (extractor.create_schema()
# Extract entities
.entities(["person", "company", "location"])
# Classify sentiment
.classification("sentiment", ["positive", "negative", "neutral"])
# Extract structured product info
.structure("product")
.field("name", dtype="str")
.field("price", dtype="str")
.field("features", dtype="list")
.field("category", dtype="str", choices=["electronics", "software", "service"])
)
text = "Apple CEO Tim Cook announced iPhone 15 for $999 with amazing new features. This is exciting!"
results = extractor.extract(text, schema)
# Output: {
# 'entities': {'person': ['Tim Cook'], 'company': ['Apple'], 'location': []},
# 'sentiment': 'positive',
# 'product': [{
# 'name': 'iPhone 15',
# 'price': '$999',
# 'features': ['amazing new features'],
# 'category': 'electronics'
# }]
# }
```
### Advanced Configuration
```python
schema = (extractor.create_schema()
.classification("urgency", ["low", "medium", "high"])
.structure("support_ticket")
.field("ticket_id", dtype="str", threshold=0.9) # High precision
.field("customer", dtype="str", description="Customer name")
.field("issue", dtype="str", description="Problem description")
.field("priority", dtype="str", choices=["low", "medium", "high", "urgent"])
.field("tags", dtype="list", choices=["bug", "feature", "support", "billing"])
)
```
## Real-World Examples
### Financial Transaction Processing
```python
text = """
Goldman Sachs processed a $2.5M equity trade for Tesla Inc. on March 15, 2024.
Commission: $1,250. Status: Completed.
"""
results = extractor.extract_json(
text,
{
"transaction": [
"broker::str::Financial institution",
"amount::str::Transaction amount",
"security::str::Stock or financial instrument",
"date::str::Transaction date",
"commission::str::Fees charged",
"status::[pending|completed|failed]::str",
"type::[equity|bond|option|future]::str"
]
}
)
# Output: {
# 'transaction': [{
# 'broker': 'Goldman Sachs',
# 'amount': '$2.5M',
# 'security': 'Tesla Inc.',
# 'date': 'March 15, 2024',
# 'commission': '$1,250',
# 'status': 'completed',
# 'type': 'equity'
# }]
# }
```
### Medical Prescription Extraction
```python
text = """
Patient: Sarah Johnson, 34, presented with chest pain.
Prescribed: Lisinopril 10mg daily, Metoprolol 25mg twice daily.
Follow-up scheduled for next Tuesday.
"""
results = extractor.extract_json(
text,
{
"patient": [
"name::str::Patient full name",
"age::str::Patient age",
"symptoms::list::Reported symptoms"
],
"prescription": [
"medication::str::Drug name",
"dosage::str::Dosage amount",
"frequency::str::How often to take"
]
}
)
# Output: {
# 'patient': [{
# 'name': 'Sarah Johnson',
# 'age': '34',
# 'symptoms': ['chest pain']
# }],
# 'prescription': [
# {'medication': 'Lisinopril', 'dosage': '10mg', 'frequency': 'daily'},
# {'medication': 'Metoprolol', 'dosage': '25mg', 'frequency': 'twice daily'}
# ]
# }
```
### E-commerce Order Processing
```python
text = """
Order #ORD-2024-001 for Alexandra Thompson
Items: Laptop Stand (2x $45.99), Wireless Mouse (1x $29.99), USB Hub (3x $35.50)
Subtotal: $228.46, Tax: $18.28, Total: $246.74
Status: Processing
"""
results = extractor.extract_json(
text,
{
"order": [
"order_id::str::Order number",
"customer::str::Customer name",
"items::list::Product names",
"quantities::list::Item quantities",
"unit_prices::list::Individual prices",
"subtotal::str",
"tax::str",
"total::str",
"status::[pending|processing|shipped|delivered]::str"
]
}
)
# Output: {
# 'order': [{
# 'order_id': 'ORD-2024-001',
# 'customer': 'Alexandra Thompson',
# 'items': ['Laptop Stand', 'Wireless Mouse', 'USB Hub'],
# 'quantities': ['2', '1', '3'],
# 'unit_prices': ['$45.99', '$29.99', '$35.50'],
# 'subtotal': '$228.46',
# 'tax': '$18.28',
# 'total': '$246.74',
# 'status': 'processing'
# }]
# }
```
## Confidence Scores and Character Positions
You can include confidence scores and character-level start/end positions for structured extraction:
```python
# Extract with confidence scores
text = "The MacBook Pro costs $1999 and features M3 chip, 16GB RAM, and 512GB storage."
results = extractor.extract_json(
text,
{
"product": [
"name::str",
"price",
"features"
]
},
include_confidence=True
)
# Output: {
# 'product': [{
# 'name': {'text': 'MacBook Pro', 'confidence': 0.95},
# 'price': [{'text': '$1999', 'confidence': 0.92}],
# 'features': [
# {'text': 'M3 chip', 'confidence': 0.88},
# {'text': '16GB RAM', 'confidence': 0.90},
# {'text': '512GB storage', 'confidence': 0.87}
# ]
# }]
# }
# Extract with character positions (spans)
results = extractor.extract_json(
text,
{
"product": [
"name::str",
"price"
]
},
include_spans=True
)
# Output: {
# 'product': [{
# 'name': {'text': 'MacBook Pro', 'start': 4, 'end': 15},
# 'price': [{'text': '$1999', 'start': 22, 'end': 27}]
# }]
# }
# Extract with both confidence and spans
results = extractor.extract_json(
text,
{
"product": [
"name::str",
"price",
"features"
]
},
include_confidence=True,
include_spans=True
)
# Output: {
# 'product': [{
# 'name': {'text': 'MacBook Pro', 'confidence': 0.95, 'start': 4, 'end': 15},
# 'price': [{'text': '$1999', 'confidence': 0.92, 'start': 22, 'end': 27}],
# 'features': [
# {'text': 'M3 chip', 'confidence': 0.88, 'start': 41, 'end': 48},
# {'text': '16GB RAM', 'confidence': 0.90, 'start': 50, 'end': 58},
# {'text': '512GB storage', 'confidence': 0.87, 'start': 64, 'end': 77}
# ]
# }]
# }
```
**Note**: When `include_spans` or `include_confidence` is True:
- **String fields** (`dtype="str"`): Return dicts with `{'text': '...', 'confidence': 0.95, 'start': 0, 'end': 5}` (or subset)
- **List fields** (`dtype="list"`): Return lists of dicts, each with text, confidence, and positions
- **Default** (both False): Returns simple strings or lists of strings
## Best Practices
### Data Types
- Use `::str` for single values (IDs, names, amounts)
- Use `::list` or default for multiple values (features, items, tags)
- Use choices `[opt1|opt2|opt3]` for standardized values
- Add descriptions for complex or domain-specific fields
### Quick Decision Guide
**Use `extract_json()`** for:
- Structure-only extraction
- Quick data parsing
- Single extraction task
**Use `create_schema().extract()`** for:
- Multi-task scenarios (entities + structures + classification)
- When you need entities or classification alongside structures
- Complex extraction pipelines

# GLiNER2 Combining Schemas Tutorial
## Table of Contents
- [Why Combine Schemas](#why-combine-schemas)
- [Basic Combinations](#basic-combinations)
- [Advanced Multi-Task Schemas](#advanced-multi-task-schemas)
- [Real-World Applications](#real-world-applications)
## Why Combine Schemas
Combining schemas allows you to:
- Extract multiple types of information in one pass
- Maintain context between different extraction tasks
- Improve efficiency by avoiding multiple model calls
- Build comprehensive information extraction pipelines
## Basic Combinations
### Entities + Classification
```python
from gliner2 import GLiNER2
extractor = GLiNER2.from_pretrained("your-model-name")
# Sentiment analysis with entity extraction
schema = (extractor.create_schema()
.entities(["person", "product", "company"])
.classification("sentiment", ["positive", "negative", "neutral"])
.classification("category", ["review", "news", "opinion"])
)
text = "Tim Cook announced that Apple's new iPhone is exceeding sales expectations."
results = extractor.extract(text, schema)
# Output: {
# 'entities': {
# 'person': ['Tim Cook'],
# 'product': ['iPhone'],
# 'company': ['Apple']
# },
# 'sentiment': 'positive',
# 'category': 'news'
# }
```
### Entities + Structures
```python
schema = (extractor.create_schema()
.entities({
"person": "Names of people mentioned",
"date": "Dates and time references"
})
.structure("appointment")
.field("patient", dtype="str")
.field("doctor", dtype="str")
.field("date")
.field("time")
.field("type", dtype="str", choices=["checkup", "followup", "consultation"])
)
text = """
Dr. Sarah Johnson confirmed the appointment with John Smith for
March 15th at 2:30 PM. This will be a follow-up consultation
regarding his previous visit on February 1st.
"""
results = extractor.extract(text, schema)
```
### Classification + Structures
```python
schema = (extractor.create_schema()
.classification("email_type",
["order_confirmation", "shipping_update", "promotional", "support"])
.classification("priority", ["urgent", "normal", "low"])
.structure("order_info")
.field("order_number", dtype="str")
.field("items")
.field("total", dtype="str")
.field("status", dtype="str",
choices=["pending", "processing", "shipped", "delivered"])
)
```
## Advanced Multi-Task Schemas
### Complete Document Analysis
```python
# Comprehensive invoice extraction
invoice_schema = (extractor.create_schema()
# Document classification
.classification("document_type",
["invoice", "credit_note", "purchase_order", "receipt"])
.classification("payment_status",
["paid", "unpaid", "partial", "overdue"])
# Key entities
.entities({
"company": "Company names (buyer or seller)",
"person": "Contact person names",
"date": "Important dates",
"amount": "Monetary amounts"
})
# Structured information
.structure("invoice_header")
.field("invoice_number", dtype="str")
.field("issue_date", dtype="str")
.field("due_date", dtype="str")
.field("vendor_name", dtype="str")
.field("customer_name", dtype="str")
.structure("line_item")
.field("description", dtype="str")
.field("quantity")
.field("unit_price")
.field("amount")
.field("tax_rate", dtype="str", choices=["0%", "5%", "10%", "20%"])
.structure("payment_info")
.field("method", dtype="str",
choices=["bank_transfer", "credit_card", "check", "cash"])
.field("terms", description="Payment terms like NET30")
.field("bank_details", dtype="list")
)
```
### Customer Feedback Analysis
```python
feedback_schema = (extractor.create_schema()
# Overall classifications
.classification("sentiment", ["positive", "negative", "neutral", "mixed"])
.classification("intent", {
"complaint": "Customer expressing dissatisfaction",
"compliment": "Customer expressing satisfaction",
"suggestion": "Customer providing improvement ideas",
"question": "Customer asking for information"
}, multi_label=True)
# Extract mentioned entities
.entities({
"product": "Products or services mentioned",
"feature": "Specific features discussed",
"competitor": "Competing products mentioned",
"price_mention": "Price points or cost references"
})
# Structured feedback components
.structure("issue")
.field("problem", dtype="str")
.field("severity", dtype="str", choices=["critical", "major", "minor"])
.field("affected_area", dtype="list")
.structure("suggestion")
.field("improvement", dtype="str")
.field("benefit", description="Expected benefit of the suggestion")
)
```
### News Article Analysis
```python
news_schema = (extractor.create_schema()
# Article metadata
.classification("category",
["politics", "business", "technology", "sports", "entertainment"])
.classification("bias", ["left", "center", "right", "neutral"])
.classification("factuality", ["fact", "opinion", "analysis", "speculation"])
# Key entities
.entities({
"person": "People mentioned in the article",
"organization": "Companies, agencies, or groups",
"location": "Places, cities, or countries",
"event": "Named events or incidents"
})
# Structured content
.structure("quote")
.field("speaker", dtype="str")
.field("statement", dtype="str")
.field("context", description="Context of the quote")
.structure("claim")
.field("statement", dtype="str")
.field("source", dtype="str")
.field("evidence", dtype="list")
)
```
## Real-World Applications
### E-commerce Product Listing
```python
product_schema = (extractor.create_schema()
# Listing classification
.classification("condition", ["new", "used", "refurbished", "for_parts"])
.classification("listing_type", ["buy_now", "auction", "best_offer"])
# Extract key entities
.entities({
"brand": "Product brand or manufacturer",
"model": "Specific model name or number",
"color": "Product colors mentioned",
"size": "Size specifications"
})
# Product details
.structure("product")
.field("title", dtype="str")
.field("price", dtype="str")
.field("features", dtype="list")
.field("category", dtype="str")
# Shipping information
.structure("shipping")
.field("method", dtype="list",
choices=["standard", "express", "overnight", "international"])
.field("cost", dtype="str")
.field("delivery_time", description="Estimated delivery timeframe")
# Seller information
.structure("seller")
.field("name", dtype="str")
.field("rating", dtype="str")
.field("location", dtype="str")
)
```
### Healthcare Clinical Note
```python
clinical_schema = (extractor.create_schema()
# Note classification
.classification("visit_type",
["initial_consultation", "follow_up", "emergency", "routine_checkup"])
.classification("urgency", ["urgent", "routine", "elective"])
# Medical entities
.entities({
"symptom": "Patient reported symptoms",
"diagnosis": "Medical diagnoses or conditions",
"medication": "Prescribed or mentioned medications",
"procedure": "Medical procedures or tests",
"body_part": "Anatomical references"
})
# Patient information
.structure("patient_info")
.field("name", dtype="str")
.field("age", dtype="str")
.field("gender", dtype="str", choices=["male", "female", "other"])
.field("chief_complaint", dtype="str")
# Clinical findings
.structure("vital_signs")
.field("blood_pressure", dtype="str")
.field("heart_rate", dtype="str")
.field("temperature", dtype="str")
.field("respiratory_rate", dtype="str")
# Treatment plan
.structure("prescription")
.field("medication", dtype="str")
.field("dosage", dtype="str")
.field("frequency")
.field("duration")
.field("route", dtype="str", choices=["oral", "IV", "topical", "injection"])
)
```
### Legal Document Analysis
```python
legal_schema = (extractor.create_schema()
# Document classification
.classification("document_type",
["contract", "memorandum", "brief", "motion", "order"])
.classification("jurisdiction",
["federal", "state", "local", "international"])
# Legal entities
.entities({
"party": "Parties involved (plaintiff, defendant, etc.)",
"attorney": "Legal representatives",
"judge": "Judicial officers",
"statute": "Laws or regulations cited",
"case_citation": "Referenced legal cases"
})
# Contract terms
.structure("contract_term")
.field("clause_type", dtype="str",
choices=["payment", "delivery", "warranty", "liability", "termination"])
.field("obligation", dtype="str")
.field("party_responsible", dtype="str")
.field("deadline")
# Legal claims
.structure("claim")
.field("type", dtype="str")
.field("plaintiff", dtype="str")
.field("defendant", dtype="str")
.field("amount", dtype="str")
.field("basis", description="Legal basis for the claim")
)
```
## Using Confidence Scores and Character Positions with Combined Schemas
When using combined schemas, the `include_confidence` and `include_spans` parameters apply to all extraction types:
```python
schema = (extractor.create_schema()
.entities(["person", "company"])
.classification("sentiment", ["positive", "negative", "neutral"])
.relations(["works_for"])
.structure("product")
.field("name", dtype="str")
.field("price", dtype="str")
)
text = "Tim Cook works for Apple. The iPhone 15 costs $999. This is exciting!"
results = extractor.extract(
text,
schema,
include_confidence=True,
include_spans=True
)
# Output: {
# 'entities': {
# 'person': [
# {'text': 'Tim Cook', 'confidence': 0.95, 'start': 0, 'end': 8}
# ],
# 'company': [
#       {'text': 'Apple', 'confidence': 0.92, 'start': 19, 'end': 24}
# ]
# },
# 'sentiment': {'label': 'positive', 'confidence': 0.88},
# 'relation_extraction': {
# 'works_for': [{
# 'head': {'text': 'Tim Cook', 'confidence': 0.95, 'start': 0, 'end': 8},
#       'tail': {'text': 'Apple', 'confidence': 0.92, 'start': 19, 'end': 24}
# }]
# },
# 'product': [{
# 'name': {'text': 'iPhone 15', 'confidence': 0.90, 'start': 30, 'end': 39},
#     'price': {'text': '$999', 'confidence': 0.88, 'start': 46, 'end': 50}
# }]
# }
```
**Note**: The `include_confidence` and `include_spans` parameters work consistently across all extraction types (entities, classifications, relations, and structures) when using combined schemas.
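The span fields make it straightforward to map extractions back into the source text for highlighting or auditing. Below is a minimal post-processing sketch in plain Python that walks the output shape shown above; `collect_spans` is an illustrative helper, not part of the GLiNER2 API:

```python
def collect_spans(results: dict) -> list[tuple[int, int, str, str]]:
    """Gather (start, end, label, text) tuples from a combined-results dict."""
    spans = []
    # Entities: {label: [{'text', 'confidence', 'start', 'end'}, ...]}
    for label, matches in results.get("entities", {}).items():
        for m in matches:
            spans.append((m["start"], m["end"], label, m["text"]))
    # Relations: {rel: [{'head': {...}, 'tail': {...}}, ...]}
    for rel, pairs in results.get("relation_extraction", {}).items():
        for pair in pairs:
            for role in ("head", "tail"):
                m = pair[role]
                spans.append((m["start"], m["end"], f"{rel}:{role}", m["text"]))
    return sorted(spans)

results = {
    "entities": {"person": [{"text": "Tim Cook", "confidence": 0.95, "start": 0, "end": 8}]},
    "relation_extraction": {"works_for": [{
        "head": {"text": "Tim Cook", "confidence": 0.95, "start": 0, "end": 8},
        "tail": {"text": "Apple", "confidence": 0.92, "start": 19, "end": 24},
    }]},
}
print(collect_spans(results))
```

Sorting by start offset gives a ready-made order for rendering highlights over the original text.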

# GLiNER2 Regex Validators
Regex validators filter extracted spans to ensure they match expected patterns, improving extraction quality and reducing false positives.
## Quick Start
```python
from gliner2 import GLiNER2, RegexValidator
extractor = GLiNER2.from_pretrained("your-model")
# Create validator and apply to field
email_validator = RegexValidator(r"^[\w\.-]+@[\w\.-]+\.\w+$")
schema = (extractor.create_schema()
.structure("contact")
.field("email", dtype="str", validators=[email_validator])
)
```
## RegexValidator Parameters
- **pattern**: Regex pattern (string or compiled Pattern)
- **mode**: `"full"` (exact match) or `"partial"` (substring match)
- **exclude**: `False` (keep matches) or `True` (exclude matches)
- **flags**: Regex flags like `re.IGNORECASE` (for string patterns only)
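Assuming `"full"` behaves like `re.fullmatch` and `"partial"` like `re.search`, the interaction of these parameters can be sketched with the standard library alone (`keeps` is an illustrative stand-in for the validator's filtering decision, not GLiNER2 code):

```python
import re

def keeps(pattern, span, mode="full", exclude=False, flags=0):
    """Sketch of RegexValidator filtering: does this span survive?"""
    matcher = re.fullmatch if mode == "full" else re.search
    matched = matcher(pattern, span, flags) is not None
    return matched != exclude  # keep matches, unless exclude inverts that

assert keeps(r"^\d+$", "12345")                      # full match kept
assert not keeps(r"^\d+$", "12a45")                  # non-match dropped
assert keeps(r"^https?://", "https://example.com", mode="partial")
assert not keeps(r"^test", "Test Phone", mode="partial", exclude=True,
                 flags=re.IGNORECASE)                # excluded match dropped
```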
## Examples
### Email Validation
```python
email_validator = RegexValidator(r"^[\w\.-]+@[\w\.-]+\.\w+$")
text = "Contact: john@company.com, not-an-email, jane@domain.org"
# Output: ['john@company.com', 'jane@domain.org']
```
### Phone Numbers (US Format)
```python
phone_validator = RegexValidator(r"\(\d{3}\)\s\d{3}-\d{4}", mode="partial")
text = "Call (555) 123-4567 or 5551234567"
# Output: ['(555) 123-4567'] # Second number filtered out
```
### URLs Only
```python
url_validator = RegexValidator(r"^https?://", mode="partial")
text = "Visit https://example.com or www.site.com"
# Output: ['https://example.com'] # www.site.com filtered out
```
### Exclude Test Data
```python
import re

no_test_validator = RegexValidator(r"^(test|demo|sample)", mode="partial", exclude=True, flags=re.IGNORECASE)
text = "Products: iPhone, Test Phone, Samsung Galaxy"
# Output: ['iPhone', 'Samsung Galaxy'] # Test Phone excluded
```
### Length Constraints
```python
length_validator = RegexValidator(r"^.{5,50}$") # 5-50 characters
text = "Names: Jo, Alexander, A Very Long Name That Definitely Exceeds Fifty Characters"
# Output: ['Alexander'] # Others filtered by length
```
### Multiple Validators
```python
import re

# All validators must pass
username_validators = [
RegexValidator(r"^[a-zA-Z0-9_]+$"), # Alphanumeric + underscore
RegexValidator(r"^.{3,20}$"), # 3-20 characters
    RegexValidator(r"^admin", mode="partial", exclude=True, flags=re.IGNORECASE)  # No "admin" prefix
]
schema = (extractor.create_schema()
.structure("user")
.field("username", dtype="str", validators=username_validators)
)
text = "Users: ab, john_doe, user@domain, admin, valid_user123"
# Output: ['john_doe', 'valid_user123']
```
## Common Patterns
| Use Case | Pattern | Mode |
|----------|---------|------|
| Email | `r"^[\w\.-]+@[\w\.-]+\.\w+$"` | full |
| Phone (US) | `r"\(\d{3}\)\s\d{3}-\d{4}"` | partial |
| URL | `r"^https?://"` | partial |
| Numbers only | `r"^\d+$"` | full |
| No spaces | `r"^\S+$"` | full |
| Min length | `r"^.{5,}$"` | full |
| Alphanumeric | `r"^[a-zA-Z0-9]+$"` | full |
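The table's patterns can be sanity-checked with `re` before wiring them into validators. This quick self-test pairs each pattern with a sample string and the expected keep/drop decision (a standalone check, not part of the GLiNER2 API):

```python
import re

checks = [
    (r"^[\w\.-]+@[\w\.-]+\.\w+$", "full",    "john@company.com", True),
    (r"\(\d{3}\)\s\d{3}-\d{4}",   "partial", "5551234567",       False),
    (r"^https?://",               "partial", "www.site.com",     False),
    (r"^\d+$",                    "full",    "12345",            True),
    (r"^\S+$",                    "full",    "two words",        False),
    (r"^.{5,}$",                  "full",    "abcd",             False),
]
for pattern, mode, sample, expected in checks:
    fn = re.fullmatch if mode == "full" else re.search
    assert (fn(pattern, sample) is not None) == expected, (pattern, sample)
print("all patterns behave as the table describes")
```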
## Best Practices
1. **Use specific patterns** - More specific = fewer false positives
2. **Test your regex** - Validate patterns before deployment
3. **Combine validators** - Chain multiple simple validators
4. **Consider case sensitivity** - Use `re.IGNORECASE` when needed
5. **Start simple** - Begin with basic patterns, refine as needed
## Performance Notes
- Validators run after span extraction but before formatting
- Failed validation simply excludes the span (no errors)
- Multiple validators use short-circuit evaluation (stops at first failure)
- Compiled patterns are cached automatically
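The short-circuit behavior can be pictured as a simple loop: each validator runs in order, and the first failure stops the chain. A pure-Python sketch using the tutorial's username rules (with `run_chain` and the lambda checks as hypothetical stand-ins for the real validators):

```python
import re

def run_chain(span, validators):
    """Apply validators in order; stop at the first failure (short-circuit)."""
    for check in validators:
        if not check(span):
            return False  # remaining validators never run
    return True

# Hypothetical username chain mirroring the tutorial's rules
chain = [
    lambda s: re.fullmatch(r"[a-zA-Z0-9_]+", s) is not None,  # charset
    lambda s: 3 <= len(s) <= 20,                              # length
    lambda s: not re.match(r"admin", s, re.IGNORECASE),       # no "admin" prefix
]
usernames = ["ab", "john_doe", "user@domain", "admin", "valid_user123"]
print([u for u in usernames if run_chain(u, chain)])  # ['john_doe', 'valid_user123']
```

Ordering cheap checks (like length) before expensive patterns keeps the short-circuit working in your favor.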

# GLiNER2 Relation Extraction Tutorial
Learn how to extract relations between entities from text using GLiNER2's relation extraction capabilities.
## Table of Contents
- [Basic Relation Extraction](#basic-relation-extraction)
- [Multiple Relation Types](#multiple-relation-types)
- [Relation Extraction with Descriptions](#relation-extraction-with-descriptions)
- [Custom Thresholds](#custom-thresholds)
- [Batch Processing](#batch-processing)
- [Combining with Other Tasks](#combining-with-other-tasks)
- [Real-World Examples](#real-world-examples)
- [Best Practices](#best-practices)
## Basic Relation Extraction
### Simple Example
```python
from gliner2 import GLiNER2
# Load model
extractor = GLiNER2.from_pretrained("your-model-name")
# Extract relations
text = "John works for Apple Inc. and lives in San Francisco."
results = extractor.extract_relations(
text,
["works_for", "lives_in"]
)
print(results)
# Output: {
# 'relation_extraction': {
# 'works_for': [('John', 'Apple Inc.')],
# 'lives_in': [('John', 'San Francisco')]
# }
# }
```
### Using Schema Builder
```python
# Same extraction using schema
schema = extractor.create_schema().relations([
"works_for", "lives_in"
])
results = extractor.extract(text, schema)
```
### Understanding the Output Format
Relations are returned as tuples `(source, target)` grouped under the `relation_extraction` key. **All requested relation types are included in the output, even if no relations are found** (they appear as empty lists `[]`):
```python
text = "Alice manages the Engineering team. Bob reports to Alice."
results = extractor.extract_relations(
text,
["manages", "reports_to", "founded"] # Note: "founded" not found in text
)
# Output: {
# 'relation_extraction': {
# 'manages': [('Alice', 'Engineering team')],
# 'reports_to': [('Bob', 'Alice')],
# 'founded': [] # Empty list - relation type requested but not found
# }
# }
```
This guarantees a consistent output structure: every requested relation type is always present in the results, which makes the output easier to process programmatically.
## Multiple Relation Types
You can extract multiple relation types in a single call:
```python
text = """
Sarah founded TechCorp in 2020. She is married to Mike,
who works at Google. TechCorp is located in Seattle.
"""
results = extractor.extract_relations(
text,
["founded", "married_to", "works_at", "located_in"]
)
# Output: {
# 'relation_extraction': {
# 'founded': [('Sarah', 'TechCorp')],
# 'married_to': [('Sarah', 'Mike')],
# 'works_at': [('Mike', 'Google')],
# 'located_in': [('TechCorp', 'Seattle')]
# }
# }
```
### Multiple Instances per Relation Type
GLiNER2 automatically extracts all relation instances found in the text:
```python
text = """
John works for Microsoft. Mary works for Google.
Bob works for Apple. All three live in California.
"""
results = extractor.extract_relations(
text,
["works_for", "lives_in"]
)
# Output: {
# 'relation_extraction': {
# 'works_for': [
# ('John', 'Microsoft'),
# ('Mary', 'Google'),
# ('Bob', 'Apple')
# ],
# 'lives_in': [
# ('John', 'California'),
# ('Mary', 'California'),
# ('Bob', 'California')
# ]
# }
# }
```
## Relation Extraction with Descriptions
Providing descriptions helps improve extraction accuracy by clarifying the relation semantics:
```python
schema = extractor.create_schema().relations({
"works_for": "Employment relationship where person works at organization",
"founded": "Founding relationship where person created organization",
"acquired": "Acquisition relationship where company bought another company",
"located_in": "Geographic relationship where entity is in a location"
})
text = """
Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne, California.
Tesla acquired SolarCity in 2016. Many engineers work for SpaceX.
"""
results = extractor.extract(text, schema)
```
### Advanced Configuration
```python
schema = extractor.create_schema().relations({
"works_for": {
"description": "Employment or professional relationship",
"threshold": 0.7 # Higher precision for employment relations
},
"located_in": {
"description": "Geographic containment relationship",
"threshold": 0.6 # Moderate threshold
},
"reports_to": {
"description": "Organizational hierarchy relationship",
"threshold": 0.8 # Very high precision
}
})
```
## Custom Thresholds
### Global Threshold
```python
# High-precision relation extraction
results = extractor.extract_relations(
text,
["acquired", "merged_with"],
threshold=0.8 # High confidence required
)
```
### Per-Relation Thresholds
```python
schema = extractor.create_schema().relations({
"acquired": {
"description": "Company acquisition relationship",
"threshold": 0.9 # Very high precision
},
"partnered_with": {
"description": "Partnership or collaboration relationship",
"threshold": 0.6 # Moderate threshold
},
"competes_with": {
"description": "Competitive relationship",
"threshold": 0.5 # Lower threshold for implicit relations
}
})
```
### With Confidence Scores and Character Positions
You can include confidence scores and character-level start/end positions for relation extractions:
```python
# Extract relations with confidence scores
text = "John works for Apple Inc. and lives in San Francisco."
results = extractor.extract_relations(
text,
["works_for", "lives_in"],
include_confidence=True
)
print(results)
# Output: {
# 'relation_extraction': {
# 'works_for': [{
# 'head': {'text': 'John', 'confidence': 0.95},
# 'tail': {'text': 'Apple Inc.', 'confidence': 0.92}
# }],
# 'lives_in': [{
# 'head': {'text': 'John', 'confidence': 0.94},
# 'tail': {'text': 'San Francisco', 'confidence': 0.91}
# }]
# }
# }
# Extract with character positions (spans)
results = extractor.extract_relations(
text,
["works_for", "lives_in"],
include_spans=True
)
print(results)
# Output: {
# 'relation_extraction': {
# 'works_for': [{
# 'head': {'text': 'John', 'start': 0, 'end': 4},
# 'tail': {'text': 'Apple Inc.', 'start': 15, 'end': 25}
# }],
# 'lives_in': [{
# 'head': {'text': 'John', 'start': 0, 'end': 4},
#       'tail': {'text': 'San Francisco', 'start': 39, 'end': 52}
# }]
# }
# }
# Extract with both confidence and spans
results = extractor.extract_relations(
text,
["works_for", "lives_in"],
include_confidence=True,
include_spans=True
)
print(results)
# Output: {
# 'relation_extraction': {
# 'works_for': [{
# 'head': {'text': 'John', 'confidence': 0.95, 'start': 0, 'end': 4},
# 'tail': {'text': 'Apple Inc.', 'confidence': 0.92, 'start': 15, 'end': 25}
# }],
# 'lives_in': [{
# 'head': {'text': 'John', 'confidence': 0.94, 'start': 0, 'end': 4},
#       'tail': {'text': 'San Francisco', 'confidence': 0.91, 'start': 39, 'end': 52}
# }]
# }
# }
```
**Note**: When `include_spans` or `include_confidence` is True, relations are returned as dictionaries with `head` and `tail` keys, each containing the extracted text along with optional confidence scores and character positions. When both are False (default), relations are returned as simple tuples `(head, tail)`.
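Because the return shape changes with these flags, a small normalizer keeps downstream code uniform regardless of which flags were set. A sketch (`as_tuples` is an illustrative helper, not part of the GLiNER2 API):

```python
def as_tuples(relations: dict) -> dict:
    """Reduce either relation output shape to plain (head, tail) tuples."""
    out = {}
    for rel, items in relations.items():
        pairs = []
        for item in items:
            if isinstance(item, dict):  # include_confidence / include_spans form
                pairs.append((item["head"]["text"], item["tail"]["text"]))
            else:                       # default tuple form
                pairs.append(tuple(item))
        out[rel] = pairs
    return out

rich = {"works_for": [{
    "head": {"text": "John", "confidence": 0.95, "start": 0, "end": 4},
    "tail": {"text": "Apple Inc.", "confidence": 0.92, "start": 15, "end": 25},
}]}
print(as_tuples(rich))  # {'works_for': [('John', 'Apple Inc.')]}
```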
## Batch Processing
Process multiple texts efficiently:
```python
texts = [
"John works for Microsoft and lives in Seattle.",
"Sarah founded TechStartup in 2020.",
"Bob reports to Alice at Google."
]
results = extractor.batch_extract_relations(
texts,
["works_for", "founded", "reports_to", "lives_in"],
batch_size=8
)
# Output: [
# {
# 'relation_extraction': {
# 'works_for': [('John', 'Microsoft')],
# 'lives_in': [('John', 'Seattle')],
# 'founded': [], # Not found in first text
# 'reports_to': [] # Not found in first text
# }
# },
# {
# 'relation_extraction': {
# 'works_for': [], # Not found in second text
# 'founded': [('Sarah', 'TechStartup')],
# 'reports_to': [], # Not found in second text
# 'lives_in': [] # Not found in second text
# }
# },
# {
# 'relation_extraction': {
# 'works_for': [('Alice', 'Google')],
# 'reports_to': [('Bob', 'Alice')],
# 'founded': [], # Not found in third text
# 'lives_in': [] # Not found in third text
# }
# }
# ]
```
**Note**: All requested relation types appear in each result, even if empty. This ensures consistent structure across all batch results, making it easier to process programmatically.
## Combining with Other Tasks
Relation extraction can be combined with entity extraction, classification, and structured extraction:
### Relations + Entities
```python
schema = (extractor.create_schema()
.entities(["person", "organization", "location"])
.relations(["works_for", "located_in"])
)
text = "Tim Cook works for Apple Inc., which is located in Cupertino, California."
results = extractor.extract(text, schema)
# Output: {
# 'entities': {
# 'person': ['Tim Cook'],
# 'organization': ['Apple Inc.'],
# 'location': ['Cupertino', 'California']
# },
# 'relation_extraction': {
# 'works_for': [('Tim Cook', 'Apple Inc.')],
# 'located_in': [('Apple Inc.', 'Cupertino')]
# }
# }
```
### Relations + Classification + Structures
```python
schema = (extractor.create_schema()
.classification("document_type", ["news", "report", "announcement"])
.entities(["person", "company"])
.relations(["works_for", "acquired"])
.structure("event")
.field("date", dtype="str")
.field("description", dtype="str")
)
text = """
BREAKING: Microsoft announced today that it acquired GitHub.
Satya Nadella, CEO of Microsoft, confirmed the deal.
The acquisition was finalized on October 26, 2018.
"""
results = extractor.extract(text, schema)
```
## Real-World Examples
### Organizational Relationships
```python
org_schema = extractor.create_schema().relations({
"reports_to": "Direct reporting relationship in organizational hierarchy",
"manages": "Management relationship where person manages team/department",
"works_for": "Employment relationship",
"founded": "Founding relationship",
"acquired": "Company acquisition relationship"
})
text = """
Sundar Pichai is the CEO of Google. He reports to the board of directors.
Google acquired YouTube in 2006. Many engineers work for Google.
"""
results = extractor.extract(text, org_schema)
# Output: {
# 'relation_extraction': {
# 'reports_to': [('Sundar Pichai', 'board of directors')],
# 'works_for': [('engineers', 'Google')],
#     'acquired': [('Google', 'YouTube')],
#     'manages': [],  # Not found in text
#     'founded': []   # Not found in text
# }
# }
```
### Medical Relationships
```python
medical_schema = extractor.create_schema().relations({
"treats": "Medical treatment relationship between doctor and patient",
"prescribed_for": "Prescription relationship between medication and condition",
"causes": "Causal relationship between condition and symptom",
"located_in": "Anatomical location relationship"
})
text = """
Dr. Smith treats patients with diabetes. Metformin is prescribed for Type 2 Diabetes.
High blood sugar causes frequent urination. The pancreas is located in the abdomen.
"""
results = extractor.extract(text, medical_schema)
```
### Financial Relationships
```python
finance_schema = extractor.create_schema().relations({
"invested_in": "Investment relationship between investor and company",
"acquired": "Company acquisition relationship",
"merged_with": "Merger relationship between companies",
"owns": "Ownership relationship"
})
text = """
SoftBank invested in Uber in 2018. Microsoft acquired LinkedIn in 2016.
Disney merged with 21st Century Fox. Berkshire Hathaway owns Geico.
"""
results = extractor.extract(text, finance_schema)
```
### Geographic Relationships
```python
geo_schema = extractor.create_schema().relations({
"located_in": "Geographic containment (city in country, etc.)",
"borders": "Geographic adjacency relationship",
"capital_of": "Capital city relationship",
"flows_through": "River or waterway relationship"
})
text = """
Paris is the capital of France. France borders Germany and Spain.
The Seine flows through Paris. Paris is located in France.
"""
results = extractor.extract(text, geo_schema)
```
### Family Relationships
```python
family_schema = extractor.create_schema().relations({
"married_to": "Marriage relationship",
"parent_of": "Parent-child relationship",
"sibling_of": "Sibling relationship",
"related_to": "General family relationship"
})
text = """
John is married to Mary. They are parents of two children: Alice and Bob.
Alice and Bob are siblings. Mary is related to her sister Sarah.
"""
results = extractor.extract(text, family_schema)
```
### Academic Relationships
```python
academic_schema = extractor.create_schema().relations({
"authored": "Publication relationship between author and paper",
"cited": "Citation relationship between papers",
"supervised": "Academic supervision relationship",
"affiliated_with": "Institutional affiliation relationship"
})
text = """
Dr. Johnson authored the paper on machine learning. The paper cited
previous work by Dr. Smith. Dr. Johnson supervises graduate students
at MIT, where she is affiliated with the Computer Science department.
"""
results = extractor.extract(text, academic_schema)
```
## Best Practices
### 1. Use Clear, Specific Relation Names
```python
# Good - Clear and specific
schema.relations(["works_for", "reports_to", "manages"])
# Less ideal - Too generic
schema.relations(["related", "connected", "linked"])
```
### 2. Provide Descriptions for Ambiguous Relations
```python
# Good - Clear descriptions
schema.relations({
"works_for": "Employment relationship where person works at organization",
"consulted_for": "Consulting relationship where person provides services to organization"
})
# Less ideal - No context
schema.relations(["works_for", "consulted_for"])
```
### 3. Set Appropriate Thresholds
```python
# High precision for critical relations
schema.relations({
"acquired": {
"description": "Company acquisition",
"threshold": 0.9 # Very high precision
},
"partnered_with": {
"description": "Partnership relationship",
"threshold": 0.6 # Moderate threshold
}
})
```
### 4. Combine with Entity Extraction
```python
# Extract both entities and relations for better context
schema = (extractor.create_schema()
.entities(["person", "organization"])
.relations(["works_for", "founded"])
)
```
### 5. Use Batch Processing for Multiple Texts
```python
# Efficient batch processing
results = extractor.batch_extract_relations(
texts,
relation_types,
batch_size=8 # Adjust based on your hardware
)
```
### 6. Handle Multiple Instances
```python
# GLiNER2 automatically extracts all instances
text = "John works for Apple. Mary works for Google. Bob works for Microsoft."
results = extractor.extract_relations(text, ["works_for"])
# Returns all three work relationships
```
### 7. Handle Empty Relations
All requested relation types are always included in the output, even if empty:
```python
results = extractor.extract_relations(
"John works for Microsoft.",
["works_for", "founded", "acquired"]
)
# Output: {
# 'relation_extraction': {
# 'works_for': [('John', 'Microsoft')],
# 'founded': [], # Empty - not found in text
# 'acquired': [] # Empty - not found in text
# }
# }
# This makes it easy to check for relations programmatically:
for rel_type, rels in results['relation_extraction'].items():
if rels: # Non-empty
print(f"Found {len(rels)} {rel_type} relations")
else: # Empty
print(f"No {rel_type} relations found")
```
### 8. Validate Relation Direction
Relations are directional tuples `(source, target)`:
- `works_for`: (person, organization)
- `located_in`: (entity, location)
- `reports_to`: (subordinate, manager)
- `manages`: (manager, team)
Make sure your relation names match the expected direction.
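Because direction carries the relation's meaning, inverse relations can often be derived by swapping tuples instead of running a second extraction. A pure-Python sketch (the `INVERSES` map and `add_inverses` helper are hypothetical, not GLiNER2 API):

```python
# Hypothetical inverse map: swapping (source, target) flips the relation
INVERSES = {"reports_to": "manages", "manages": "reports_to"}

def add_inverses(relations: dict) -> dict:
    """Augment extracted relations with their swapped-tuple inverses."""
    out = {k: list(v) for k, v in relations.items()}
    for rel, pairs in relations.items():
        inv = INVERSES.get(rel)
        if inv:
            out.setdefault(inv, []).extend((t, h) for h, t in pairs)
    return out

extracted = {"reports_to": [("Bob", "Alice")]}
print(add_inverses(extracted))
# {'reports_to': [('Bob', 'Alice')], 'manages': [('Alice', 'Bob')]}
```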
## Common Use Cases
### Knowledge Graph Construction
```python
# Extract entities and relations for knowledge graph
schema = (extractor.create_schema()
.entities(["person", "organization", "location", "product"])
.relations([
"works_for", "founded", "located_in", "created",
"acquired", "partnered_with"
])
)
# Process documents to build knowledge graph
documents = [...] # Your documents
all_relations = []
all_entities = []
for doc in documents:
results = extractor.extract(doc, schema)
all_relations.append(results.get("relation_extraction", {}))
all_entities.append(results.get("entities", {}))
```
### Relationship Analysis
```python
# Analyze organizational structures
org_texts = [...] # Organizational documents
results = extractor.batch_extract_relations(
org_texts,
["reports_to", "manages", "works_for", "collaborates_with"],
batch_size=8
)
# Analyze relationship patterns
for result in results:
relations = result.get("relation_extraction", {})
# Process relations for analysis
```
### Document Understanding
```python
# Comprehensive document understanding
schema = (extractor.create_schema()
.classification("document_type", ["contract", "report", "email"])
.entities(["person", "organization", "date", "amount"])
.relations(["signed_by", "involves", "dated", "worth"])
.structure("contract_term")
.field("term", dtype="str")
.field("value", dtype="str")
)
# Extract all information types in one pass
results = extractor.extract(document_text, schema)
```

# GLiNER2 API Extractor
Use GLiNER2 through a cloud API without loading models locally. Perfect for production deployments, low-memory environments, or when you need instant access without GPU setup.
## Table of Contents
- [Getting Started](#getting-started)
- [Basic Usage](#basic-usage)
- [Entity Extraction](#entity-extraction)
- [Text Classification](#text-classification)
- [Structured Extraction](#structured-extraction)
- [Relation Extraction](#relation-extraction)
- [Combined Schemas](#combined-schemas)
- [Batch Processing](#batch-processing)
- [Confidence Scores](#confidence-scores)
- [Error Handling](#error-handling)
- [API vs Local](#api-vs-local)
## Getting Started
### Get Your API Key
1. Visit [gliner.pioneer.ai](https://gliner.pioneer.ai)
2. Sign up or log in to your account
3. Navigate to API Keys section
4. Generate a new API key
### Installation
```bash
pip install gliner2
```
### Set Your API Key
**Option 1: Environment Variable (Recommended)**
```bash
export PIONEER_API_KEY="your-api-key-here"
```
**Option 2: Pass Directly**
```python
extractor = GLiNER2.from_api(api_key="your-api-key-here")
```
## Basic Usage
```python
from gliner2 import GLiNER2
# Load from API (uses PIONEER_API_KEY environment variable)
extractor = GLiNER2.from_api()
# Use exactly like the local model!
results = extractor.extract_entities(
"Apple CEO Tim Cook announced the iPhone 15 in Cupertino.",
["company", "person", "product", "location"]
)
print(results)
# Output: {
# 'entities': {
# 'company': ['Apple'],
# 'person': ['Tim Cook'],
# 'product': ['iPhone 15'],
# 'location': ['Cupertino']
# }
# }
```
## Entity Extraction
### Simple Extraction
```python
extractor = GLiNER2.from_api()
text = "Elon Musk founded SpaceX in 2002 and Tesla in 2003."
results = extractor.extract_entities(
text,
["person", "company", "date"]
)
# Output: {
# 'entities': {
# 'person': ['Elon Musk'],
# 'company': ['SpaceX', 'Tesla'],
# 'date': ['2002', '2003']
# }
# }
```
### With Confidence Scores and Character Positions
You can include confidence scores and character-level start/end positions using `include_confidence` and `include_spans`:
```python
# With confidence only
results = extractor.extract_entities(
"Microsoft acquired LinkedIn for $26.2 billion.",
["company", "price"],
include_confidence=True
)
# Output: {
# 'entities': {
# 'company': [
# {'text': 'Microsoft', 'confidence': 0.98},
# {'text': 'LinkedIn', 'confidence': 0.97}
# ],
# 'price': [
# {'text': '$26.2 billion', 'confidence': 0.95}
# ]
# }
# }
# With character positions (spans) only
results = extractor.extract_entities(
"Microsoft acquired LinkedIn.",
["company"],
include_spans=True
)
# Output: {
# 'entities': {
# 'company': [
# {'text': 'Microsoft', 'start': 0, 'end': 9},
#       {'text': 'LinkedIn', 'start': 19, 'end': 27}
# ]
# }
# }
# With both confidence and spans
results = extractor.extract_entities(
"Microsoft acquired LinkedIn for $26.2 billion.",
["company", "price"],
include_confidence=True,
include_spans=True
)
# Output: {
# 'entities': {
# 'company': [
# {'text': 'Microsoft', 'confidence': 0.98, 'start': 0, 'end': 9},
#       {'text': 'LinkedIn', 'confidence': 0.97, 'start': 19, 'end': 27}
# ],
# 'price': [
# {'text': '$26.2 billion', 'confidence': 0.95, 'start': 32, 'end': 45}
# ]
# }
# }
```
### Custom Threshold
```python
# Only return high-confidence extractions
results = extractor.extract_entities(
text,
["person", "company"],
threshold=0.8 # Minimum 80% confidence
)
```
## Text Classification
### Single-Label Classification
```python
extractor = GLiNER2.from_api()
text = "I absolutely love this product! It exceeded all my expectations."
results = extractor.classify_text(
text,
{"sentiment": ["positive", "negative", "neutral"]}
)
# Output: {'sentiment': {'category': 'positive'}}
```
### Multi-Task Classification
```python
text = "Breaking: Major earthquake hits coastal city. Rescue teams deployed."
results = extractor.classify_text(
text,
{
"category": ["politics", "sports", "technology", "disaster", "business"],
"urgency": ["low", "medium", "high"]
}
)
# Output: {'category': {'category': 'disaster'}, 'urgency': {'category': 'high'}}
```
## Structured Extraction
### Contact Information
```python
extractor = GLiNER2.from_api()
text = """
Contact John Smith at john.smith@email.com or call +1-555-123-4567.
He works as a Senior Engineer at TechCorp Inc.
"""
results = extractor.extract_json(
text,
{
"contact": [
"name::str::Full name of the person",
"email::str::Email address",
"phone::str::Phone number",
"job_title::str::Professional title",
"company::str::Company name"
]
}
)
# Output: {
# 'contact': [{
# 'name': 'John Smith',
# 'email': 'john.smith@email.com',
# 'phone': '+1-555-123-4567',
# 'job_title': 'Senior Engineer',
# 'company': 'TechCorp Inc.'
# }]
# }
```
### Product Information
```python
text = "iPhone 15 Pro Max - $1199, 256GB storage, Natural Titanium color"
results = extractor.extract_json(
text,
{
"product": [
"name::str",
"price::str",
"storage::str",
"color::str"
]
}
)
# Output: {
# 'product': [{
# 'name': 'iPhone 15 Pro Max',
# 'price': '$1199',
# 'storage': '256GB',
# 'color': 'Natural Titanium'
# }]
# }
```
## Relation Extraction
Extract relationships between entities as directional tuples (source, target).
### Basic Relation Extraction
```python
extractor = GLiNER2.from_api()
text = "John works for Apple Inc. and lives in San Francisco. Apple Inc. is located in Cupertino."
results = extractor.extract_relations(
text,
["works_for", "lives_in", "located_in"]
)
# Output: {
# 'relation_extraction': {
# 'works_for': [('John', 'Apple Inc.')],
# 'lives_in': [('John', 'San Francisco')],
# 'located_in': [('Apple Inc.', 'Cupertino')]
# }
# }
```
### With Descriptions
```python
text = "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne, California."
schema = extractor.create_schema().relations({
"founded": "Founding relationship where person created organization",
"located_in": "Geographic relationship where entity is in a location"
})
results = extractor.extract(text, schema)
# Output: {
# 'relation_extraction': {
# 'founded': [('Elon Musk', 'SpaceX')],
# 'located_in': [('SpaceX', 'Hawthorne, California')]
# }
# }
```
### Batch Relation Extraction
```python
texts = [
"John works for Microsoft and lives in Seattle.",
"Sarah founded TechStartup in 2020.",
"Bob reports to Alice at Google."
]
results = extractor.batch_extract_relations(
texts,
["works_for", "founded", "reports_to", "lives_in"]
)
# Returns list of relation extraction results for each text
```
## Combined Schemas
Combine entities, classification, relations, and structured extraction in a single call.
```python
extractor = GLiNER2.from_api()
text = """
Tech Review: The new MacBook Pro M3 is absolutely fantastic! Apple has outdone themselves.
I tested it in San Francisco last week. Tim Cook works for Apple, which is located in Cupertino.
Highly recommended for developers. Rating: 5 out of 5 stars.
"""
schema = (extractor.create_schema()
.entities(["company", "product", "location", "person"])
.classification("sentiment", ["positive", "negative", "neutral"])
.relations(["works_for", "located_in"])
.structure("review")
.field("product_name", dtype="str")
.field("rating", dtype="str")
.field("recommendation", dtype="str")
)
results = extractor.extract(text, schema)
# Output: {
# 'entities': {
# 'company': ['Apple'],
# 'product': ['MacBook Pro M3'],
# 'location': ['San Francisco', 'Cupertino'],
# 'person': ['Tim Cook']
# },
# 'sentiment': 'positive',
# 'relation_extraction': {
# 'works_for': [('Tim Cook', 'Apple')],
# 'located_in': [('Apple', 'Cupertino')]
# },
# 'review': [{
# 'product_name': 'MacBook Pro M3',
# 'rating': '5 out of 5 stars',
# 'recommendation': 'Highly recommended for developers'
# }]
# }
```
## Batch Processing
Process multiple texts efficiently in a single API call.
```python
extractor = GLiNER2.from_api()
texts = [
"Google's Sundar Pichai unveiled Gemini AI in Mountain View.",
"Microsoft CEO Satya Nadella announced Copilot at Build 2023.",
"Amazon's Andy Jassy revealed new AWS services in Seattle."
]
results = extractor.batch_extract_entities(
texts,
["company", "person", "product", "location"]
)
for i, result in enumerate(results):
print(f"Text {i+1}: {result}")
```
## Confidence Scores and Character Positions
### Entity Extraction with Confidence
```python
# Include confidence scores
results = extractor.extract_entities(
"Apple released the iPhone 15 in September 2023.",
["company", "product", "date"],
include_confidence=True
)
# Each entity includes: {'text': '...', 'confidence': 0.95}
```
### Entity Extraction with Character Positions
```python
# Include character-level start/end positions
results = extractor.extract_entities(
"Apple released the iPhone 15.",
["company", "product"],
include_spans=True
)
# Each entity includes: {'text': '...', 'start': 0, 'end': 5}
```
### Both Confidence and Positions
```python
# Include both confidence and character positions
results = extractor.extract_entities(
"Apple released the iPhone 15 in September 2023.",
["company", "product", "date"],
include_confidence=True,
include_spans=True
)
# Each entity includes: {'text': '...', 'confidence': 0.95, 'start': 0, 'end': 5}
```
### Raw Results (Advanced)
For full control over the extraction data:
```python
results = extractor.extract_entities(
"Apple CEO Tim Cook announced new products.",
["company", "person"],
format_results=False, # Get raw extraction data
include_confidence=True,
include_spans=True
)
# Returns tuples: (text, confidence, start_char, end_char)
```
## Error Handling
```python
from gliner2 import GLiNER2, GLiNER2APIError, AuthenticationError, ValidationError
try:
extractor = GLiNER2.from_api()
results = extractor.extract_entities(text, entity_types)
except AuthenticationError:
print("Invalid API key. Check your PIONEER_API_KEY.")
except ValidationError as e:
print(f"Invalid request: {e}")
except GLiNER2APIError as e:
print(f"API error: {e}")
```
### Connection Settings
```python
extractor = GLiNER2.from_api(
api_key="your-key",
timeout=60.0, # Request timeout (seconds)
max_retries=5 # Retry failed requests
)
```
## API vs Local
| Feature | API (`from_api()`) | Local (`from_pretrained()`) |
|---------|-------------------|----------------------------|
| Setup | Just API key | GPU/CPU + model download |
| Memory | ~0 MB | 2-8 GB+ |
| Latency | Network dependent | Faster for single texts |
| Batch | Optimized | Optimized |
| Cost | Per request | Free after setup |
| Offline | ❌ | ✅ |
| RegexValidator | ❌ | ✅ |
### When to Use API
- Production deployments without GPU
- Serverless functions (AWS Lambda, etc.)
- Quick prototyping
- Low-memory environments
- Mobile/edge applications
### When to Use Local
- High-volume processing
- Offline requirements
- Sensitive data (no network transfer)
- Need for RegexValidator
- Cost optimization at scale
## Seamless Switching
The API mirrors the local interface exactly, making switching trivial:
```python
# Development: Use API for quick iteration
extractor = GLiNER2.from_api()
# Production: Switch to local if needed
# extractor = GLiNER2.from_pretrained("your-model")
# Same code works with both!
results = extractor.extract_entities(text, entity_types)
```
## Limitations
The API currently does not support:
1. **RegexValidator** - Use local model for regex-based filtering
2. **Multi-schema batch** - Batches where each text uses a different schema work, but are slower than a single shared schema
3. **Custom models** - API uses the default GLiNER2 model
## Best Practices
1. **Store API key securely** - Use environment variables, not hardcoded strings
2. **Handle errors gracefully** - Network issues can occur
3. **Use batch processing** - More efficient than individual calls
4. **Set appropriate timeouts** - Increase for large texts
5. **Cache results** - Avoid redundant API calls for same content

# GLiNER2 Training Dataset Formats
GLiNER2 uses a JSONL format in which each line is a JSON object with an `input` and an `output` field (or, equivalently, `text` and `schema`). The `input`/`text` value is the text to process; the `output`/`schema` value is the schema with labels/annotations.
## Quick Format Reference
### General Structure
**Primary Format**:
```jsonl
{"input": "text to process", "output": {"schema_definition": "with_annotations"}}
```
**Alternative Format** (also supported):
```jsonl
{"text": "text to process", "schema": {"schema_definition": "with_annotations"}}
```
Both formats are equivalent - use whichever is more convenient for your workflow.
### Valid Output Schema Keys
| Key | Type | Required | Description |
|-----|------|----------|-------------|
| `entities` | `dict[str, list[str]]` | No | Entity type → list of entity mentions |
| `entity_descriptions` | `dict[str, str]` | No | Entity type → description |
| `classifications` | `list[dict]` | No | List of classification tasks |
| `json_structures` | `list[dict]` | No | List of structured data extractions |
| `json_descriptions` | `dict[str, dict[str, str]]` | No | Parent → field → description |
| `relations` | `list[dict]` | No | List of relation extractions |
### Classification Task Fields
| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `task` | `str` | Yes | Task identifier |
| `labels` | `list[str]` | Yes | Available label options |
| `true_label` | `list[str]` or `str` | Yes | Correct label(s) |
| `multi_label` | `bool` | No | Enable multi-label classification |
| `prompt` | `str` | No | Custom prompt for the task |
| `examples` | `list[list[str]]` or `list[tuple[str, str]]` | No | Few-shot examples as [[input, output], ...] pairs. Internally converted to list of lists. |
| `label_descriptions` | `dict[str, str]` | No | Label → description mapping |
### Entity Fields Format
Entities use a simple dictionary where keys are entity types and values are lists of mentions:
| Component | Type | Required | Description |
|-----------|------|----------|-------------|
| Entity type (key) | `str` | Yes | Name of the entity type (e.g., "person", "location") |
| Entity mentions (value) | `list[str]` | Yes | List of entity text spans found in input |
**Format**: `{"entity_type": ["mention1", "mention2", ...]}`
### JSON Structure Fields Format
Each structure is a dictionary with a parent name as key and field definitions as value:
| Component | Type | Required | Description |
|-----------|------|----------|-------------|
| Parent name (key) | `str` | Yes | Name of the structure (e.g., "product", "contact") |
| Fields (value) | `dict` | Yes | Field name → field value mappings |
| Field value | `str` or `list[str]` or `dict` | Yes | String, list of strings, or choice dict |
| Choice dict | `dict` with `value` and `choices` | No | For classification-style fields |
**Format**: `[{"parent": {"field1": "value", "field2": ["list", "values"]}}]`
**Multiple Instances**: When the same parent appears multiple times, each instance is a separate dict in the list:
```jsonl
[{"hotel": {"name": "Hotel A", ...}}, {"hotel": {"name": "Hotel B", ...}}]
```
### Relation Fields Format
Relations use flexible field structures - you can use ANY field names (not just "head" and "tail"):
| Component | Type | Required | Description |
|-----------|------|----------|-------------|
| Relation name (key) | `str` | Yes | Name of the relation type (e.g., "works_for") |
| Fields (value) | `dict` | Yes | Field name → field value mappings |
| Field value | `str` or `list[str]` | Yes | String or list of strings |
**Standard Format**: `[{"relation_name": {"head": "entity1", "tail": "entity2"}}]`
**⚠️ Critical Constraint**: For a given relation type, the **first occurrence** defines the field structure:
- The first instance of "works_for" determines what fields ALL "works_for" instances must have
- All subsequent instances of the same relation type must use the same field names
- Different relation types can have different field structures
- **This consistency is enforced during validation** - inconsistent field structures will raise a `ValidationError`
**Example**: If first "works_for" has `{"head": "...", "tail": "..."}`, all other "works_for" instances must also have "head" and "tail" fields.
**Validation**: The `TrainingDataset.validate_relation_consistency()` method checks that all relation types have consistent field structures across the entire dataset.
---
## Alternative Input Formats
The training data loader supports multiple input formats:
1. **JSONL files**: `{"input": "...", "output": {...}}` or `{"text": "...", "schema": {...}}`
2. **Python API**: Use `InputExample` and `TrainingDataset` classes from `gliner2.training.data`
3. **Dict lists**: List of dictionaries in the same format as JSONL
All formats are automatically detected and converted to the internal format. See `gliner2.training.data.DataLoader_Factory` for details.
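To make the two JSONL key conventions concrete, here is a minimal, standalone loader sketch (not the library's `DataLoader_Factory`) that normalizes both `input`/`output` and `text`/`schema` records to one shape:

```python
import json
from io import StringIO

def load_examples(lines):
    """Normalize both supported JSONL key conventions
    ({"input","output"} and {"text","schema"}) to one form."""
    examples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines between records
        record = json.loads(line)
        text = record.get("input", record.get("text"))
        schema = record.get("output", record.get("schema"))
        examples.append({"input": text, "output": schema})
    return examples

jsonl = StringIO(
    '{"input": "Alice works at Acme.", "output": {"entities": {"person": ["Alice"]}}}\n'
    '{"text": "Bob joined Beta.", "schema": {"entities": {"person": ["Bob"]}}}\n'
)
examples = load_examples(jsonl)
# Both records come back in the same {"input", "output"} shape.
```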
---
## 1. Classification Tasks
### Basic Single-Label Classification
```jsonl
{"input": "This movie is absolutely fantastic! I loved every minute of it.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}}
{"input": "The service at this restaurant was terrible and the food was cold.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["negative"]}]}}
{"input": "The weather today is okay, nothing special.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["neutral"]}]}}
```
### Multi-label Classification
```jsonl
{"input": "This smartphone has an amazing camera but the battery life is poor.", "output": {"classifications": [{"task": "product_aspects", "labels": ["camera", "battery", "screen", "performance", "design"], "true_label": ["camera", "battery"], "multi_label": true}]}}
{"input": "Great performance and beautiful design!", "output": {"classifications": [{"task": "product_aspects", "labels": ["camera", "battery", "screen", "performance", "design"], "true_label": ["performance", "design"], "multi_label": true}]}}
```
### Classification with Label Descriptions
```jsonl
{"input": "Breaking: New AI model achieves human-level performance on reasoning tasks.", "output": {"classifications": [{"task": "news_category", "labels": ["technology", "politics", "sports", "entertainment"], "true_label": ["technology"], "label_descriptions": {"technology": "Articles about computers, AI, software, and tech innovations", "politics": "Government, elections, and political news", "sports": "Athletic events, teams, and competitions", "entertainment": "Movies, music, celebrities, and entertainment news"}}]}}
```
### Classification with Custom Prompts
```jsonl
{"input": "The patient shows signs of improvement after treatment.", "output": {"classifications": [{"task": "medical_assessment", "labels": ["improving", "stable", "declining", "critical"], "true_label": ["improving"], "prompt": "Assess the patient's medical condition based on the clinical notes."}]}}
```
### Classification with Few-Shot Examples
Few-shot examples are provided as a list of `[input, output]` pairs. Each example is a list/tuple with exactly 2 elements:
```jsonl
{"input": "This service exceeded all my expectations!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"], "examples": [["Great product, highly recommend!", "positive"], ["Terrible experience, very disappointed.", "negative"], ["It's okay, nothing special.", "neutral"]]}]}}
```
**Format**: `"examples": [[input_text, output_label], [input_text, output_label], ...]`
Each example pair must have exactly 2 elements: the input text and the corresponding label.
### Classification with Both Examples and Descriptions
```jsonl
{"input": "The algorithm demonstrates linear time complexity.", "output": {"classifications": [{"task": "complexity", "labels": ["constant", "linear", "quadratic", "exponential"], "true_label": ["linear"], "examples": [["O(1) lookup time", "constant"], ["O(n) iteration", "linear"]], "label_descriptions": {"constant": "O(1) - fixed time regardless of input size", "linear": "O(n) - time scales linearly with input", "quadratic": "O(n²) - nested iterations", "exponential": "O(2ⁿ) - recursive branching"}}]}}
```
### Multiple Classification Tasks
```jsonl
{"input": "Exciting new smartphone with innovative features!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}, {"task": "category", "labels": ["technology", "sports", "politics", "entertainment"], "true_label": ["technology"]}]}}
```
### true_label: String vs List Format
Both formats are supported - use list for consistency or string for brevity:
```jsonl
{"input": "Sample text A", "output": {"classifications": [{"task": "label", "labels": ["a", "b"], "true_label": ["a"]}]}}
{"input": "Sample text B", "output": {"classifications": [{"task": "label", "labels": ["a", "b"], "true_label": "b"}]}}
{"input": "This is great!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": "positive"}]}}
```
**Note**:
- String format (`"true_label": "positive"`) and list format (`"true_label": ["positive"]`) are both valid for single-label classification
- Internally, string values are automatically converted to lists (`["positive"]`)
- For multi-label classification, always use list format: `"true_label": ["label1", "label2"]`
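The string-to-list conversion described above amounts to a one-line normalization; a sketch of it:

```python
def normalize_true_label(true_label):
    """String true_label values become single-element lists,
    matching the internal conversion described above."""
    if isinstance(true_label, str):
        return [true_label]
    return list(true_label)

print(normalize_true_label("positive"))            # ['positive']
print(normalize_true_label(["label1", "label2"]))  # ['label1', 'label2']
```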
---
## 2. Named Entity Recognition (NER)
### Basic NER
```jsonl
{"input": "John Smith works at OpenAI in San Francisco and will visit London next month.", "output": {"entities": {"person": ["John Smith"], "organization": ["OpenAI"], "location": ["San Francisco", "London"]}}}
{"input": "Apple Inc. CEO Tim Cook announced the iPhone 15 release date.", "output": {"entities": {"person": ["Tim Cook"], "organization": ["Apple Inc."], "product": ["iPhone 15"]}}}
{"input": "The meeting on January 15, 2024 will be held at Microsoft headquarters.", "output": {"entities": {"date": ["January 15, 2024"], "organization": ["Microsoft"]}}}
```
### NER with Entity Descriptions
```jsonl
{"input": "Dr. Sarah Johnson prescribed Metformin 500mg twice daily for diabetes treatment.", "output": {"entities": {"person": ["Dr. Sarah Johnson"], "medication": ["Metformin"], "dosage": ["500mg"], "condition": ["diabetes"]}, "entity_descriptions": {"person": "Names of people mentioned in the text", "medication": "Names of drugs or pharmaceutical products", "dosage": "Specific amounts or dosages of medications", "condition": "Medical conditions or diseases"}}}
```
### NER with Multiple Instances of Same Entity Type
```jsonl
{"input": "Alice, Bob, and Charlie attended the meeting with David.", "output": {"entities": {"person": ["Alice", "Bob", "Charlie", "David"]}}}
```
### NER with Empty Entity Types
```jsonl
{"input": "The conference will be held next week.", "output": {"entities": {"person": [], "organization": [], "location": []}}}
```
### Partial NER (Some Entity Types Present)
```jsonl
{"input": "Microsoft announced new features.", "output": {"entities": {"organization": ["Microsoft"], "person": []}}}
```
---
## 3. JSON Structure Extraction
### Basic Structure with String Fields
```jsonl
{"input": "Contact John Doe at john.doe@email.com or call (555) 123-4567.", "output": {"json_structures": [{"contact": {"name": "John Doe", "email": "john.doe@email.com", "phone": "(555) 123-4567"}}]}}
```
### Structure with List Fields
```jsonl
{"input": "Product features include: wireless charging, water resistance, and face recognition.", "output": {"json_structures": [{"product": {"features": ["wireless charging", "water resistance", "face recognition"]}}]}}
```
### Structure with Mixed String and List Fields
```jsonl
{"input": "iPhone 15 costs $999 and comes in blue, black, and white colors.", "output": {"json_structures": [{"product": {"name": "iPhone 15", "price": "$999", "colors": ["blue", "black", "white"]}}]}}
```
### Multiple Instances of Same Structure Type
When the **same structure type** (parent name) appears multiple times in the text, each instance is a **separate dictionary** in the `json_structures` list:
```jsonl
{"input": "We have two hotels available: Hotel Paradise with 4 stars, pool, and wifi for $150/night, and Budget Inn with 2 stars and parking for $80/night.", "output": {"json_structures": [{"hotel": {"name": "Hotel Paradise", "stars": "4", "amenities": ["pool", "wifi"], "price": "$150/night"}}, {"hotel": {"name": "Budget Inn", "stars": "2", "amenities": ["parking"], "price": "$80/night"}}]}}
```
**Note**: Both instances use the same parent key "hotel" but are separate objects in the list. This is how you represent multiple occurrences of the same structure type.
Another example with three products:
```jsonl
{"input": "Available products: iPhone 15 for $999, MacBook Pro for $1999, and AirPods for $199.", "output": {"json_structures": [{"product": {"name": "iPhone 15", "price": "$999"}}, {"product": {"name": "MacBook Pro", "price": "$1999"}}, {"product": {"name": "AirPods", "price": "$199"}}]}}
```
### Structure with Classification Fields (Choices)
```jsonl
{"input": "Book a single room at Grand Hotel for 2 nights with breakfast included.", "output": {"json_structures": [{"booking": {"hotel": "Grand Hotel", "room_type": {"value": "single", "choices": ["single", "double", "suite"]}, "nights": "2", "meal_plan": {"value": "breakfast", "choices": ["none", "breakfast", "half-board", "full-board"]}}}]}}
```
### Structure with Multiple Choice Fields
```jsonl
{"input": "Order a large pepperoni pizza for delivery, extra cheese.", "output": {"json_structures": [{"order": {"size": {"value": "large", "choices": ["small", "medium", "large", "xlarge"]}, "type": {"value": "pepperoni", "choices": ["cheese", "pepperoni", "veggie", "supreme"]}, "method": {"value": "delivery", "choices": ["pickup", "delivery", "dine-in"]}, "extras": ["extra cheese"]}}]}}
```
### Structure with Field Descriptions
```jsonl
{"input": "Patient: Mary Wilson, Age: 45, diagnosed with hypertension, prescribed Lisinopril 10mg daily.", "output": {"json_structures": [{"medical_record": {"patient_name": "Mary Wilson", "age": "45", "diagnosis": "hypertension", "medication": "Lisinopril", "dosage": "10mg daily"}}], "json_descriptions": {"medical_record": {"patient_name": "Full name of the patient", "age": "Patient's age in years", "diagnosis": "Medical condition diagnosed", "medication": "Prescribed medication name", "dosage": "Medication dosage and frequency"}}}}
```
### Structure with Null/Empty Field Values
```jsonl
{"input": "Product name is Widget X. Price not available.", "output": {"json_structures": [{"product": {"name": "Widget X", "price": "", "description": ""}}]}}
```
### Structure with Some Fields Missing
```jsonl
{"input": "Contact Sarah at sarah@example.com", "output": {"json_structures": [{"contact": {"name": "Sarah", "email": "sarah@example.com", "phone": ""}}]}}
```
### Multiple Different Structure Types
```jsonl
{"input": "John Doe works at TechCorp. Product ABC costs $50 with free shipping.", "output": {"json_structures": [{"employee": {"name": "John Doe", "company": "TechCorp"}}, {"product": {"name": "ABC", "price": "$50", "shipping": "free"}}]}}
```
### Structure with Only List Fields
```jsonl
{"input": "Available colors: red, blue, green. Sizes: S, M, L, XL.", "output": {"json_structures": [{"options": {"colors": ["red", "blue", "green"], "sizes": ["S", "M", "L", "XL"]}}]}}
```
---
## 4. Relation Extraction
Relations use flexible field structures. While "head" and "tail" are common, you can use ANY field names.
**⚠️ Important**: The first occurrence of each relation type defines the field structure for ALL instances of that type.
### Basic Relation (Head and Tail)
```jsonl
{"input": "Alice manages the engineering team.", "output": {"relations": [{"manages": {"head": "Alice", "tail": "engineering team"}}]}}
{"input": "John works for Microsoft.", "output": {"relations": [{"works_for": {"head": "John", "tail": "Microsoft"}}]}}
```
### Multiple Instances - Same Field Structure
All instances of the same relation type MUST have the same fields (determined by first occurrence):
```jsonl
{"input": "Alice works for Google. Bob works for Microsoft. Charlie works for Amazon.", "output": {"relations": [{"works_for": {"head": "Alice", "tail": "Google"}}, {"works_for": {"head": "Bob", "tail": "Microsoft"}}, {"works_for": {"head": "Charlie", "tail": "Amazon"}}]}}
```
**Note**: All three "works_for" instances use the same fields (head, tail) as defined by the first occurrence.
### Multiple Different Relation Types
Different relation types can have different field structures:
```jsonl
{"input": "John works for Apple Inc. and lives in San Francisco. Apple Inc. is located in Cupertino.", "output": {"relations": [{"works_for": {"head": "John", "tail": "Apple Inc."}}, {"lives_in": {"head": "John", "tail": "San Francisco"}}, {"located_in": {"head": "Apple Inc.", "tail": "Cupertino"}}]}}
```
**Note**: Each relation type ("works_for", "lives_in", "located_in") can independently define its own field structure.
### Custom Field Names (Beyond Head/Tail)
You can use custom field names - the first occurrence defines what fields to use:
```jsonl
{"input": "Alice sent $100 to Bob. Charlie sent $50 to David.", "output": {"relations": [{"transaction": {"sender": "Alice", "recipient": "Bob", "amount": "$100"}}, {"transaction": {"sender": "Charlie", "recipient": "David", "amount": "$50"}}]}}
```
**Note**: First "transaction" uses sender/recipient/amount, so all "transaction" instances must use these same fields.
### Relations with Additional Fields
```jsonl
{"input": "John Smith is the CEO of TechCorp which is headquartered in Silicon Valley.", "output": {"relations": [{"employment": {"head": "John Smith", "tail": "TechCorp", "role": "CEO"}}, {"located_in": {"head": "TechCorp", "tail": "Silicon Valley"}}]}}
```
### Relations Combined with Entities
```jsonl
{"input": "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne.", "output": {"entities": {"person": ["Elon Musk"], "organization": ["SpaceX"], "location": ["Hawthorne"], "date": ["2002"]}, "relations": [{"founded": {"head": "Elon Musk", "tail": "SpaceX"}}, {"located_in": {"head": "SpaceX", "tail": "Hawthorne"}}]}}
```
### Empty Relations (Negative Example)
```jsonl
{"input": "The weather is nice today.", "output": {"relations": []}}
```
### Bidirectional Relations
```jsonl
{"input": "Alice and Bob are colleagues.", "output": {"relations": [{"colleague_of": {"head": "Alice", "tail": "Bob"}}, {"colleague_of": {"head": "Bob", "tail": "Alice"}}]}}
```
### Field Consistency: Relations vs JSON Structures
**Key Difference**:
- **Relations**: First occurrence defines field structure for ALL instances of that relation type
- All "works_for" relations must have same fields
- Enforced consistency per relation type
- **JSON Structures**: Fields can vary between instances of the same parent type
- Uses union of all fields across instances
- More flexible - instances can have different subsets of fields
**Example - Relations (Strict Consistency)**:
```jsonl
{"input": "Alice works for Google. Bob works for Microsoft.", "output": {"relations": [{"works_for": {"head": "Alice", "tail": "Google"}}, {"works_for": {"head": "Bob", "tail": "Microsoft"}}]}}
```
✓ Valid: Both "works_for" have same fields (head, tail)
**Example - JSON Structures (Flexible Fields)**:
```jsonl
{"input": "Product A costs $10. Product B costs $20 and weighs 5kg.", "output": {"json_structures": [{"product": {"name": "A", "price": "$10"}}, {"product": {"name": "B", "price": "$20", "weight": "5kg"}}]}}
```
✓ Valid: Second instance has extra "weight" field - this is allowed for json_structures
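The first-occurrence rule for relations can be expressed as a small standalone check. This is an illustrative sketch of the rule, not the library's `validate_relation_consistency()` implementation:

```python
def check_relation_consistency(dataset):
    """First occurrence of each relation type fixes its field set;
    any later instance with different fields is reported."""
    expected = {}  # relation name -> frozenset of field names
    errors = []
    for i, example in enumerate(dataset):
        for relation in example.get("output", {}).get("relations", []):
            for name, fields in relation.items():
                field_set = frozenset(fields)
                if name not in expected:
                    expected[name] = field_set
                elif expected[name] != field_set:
                    errors.append((i, name, sorted(field_set)))
    return errors

data = [
    {"output": {"relations": [{"works_for": {"head": "Alice", "tail": "Google"}}]}},
    {"output": {"relations": [{"works_for": {"employee": "Bob", "employer": "Microsoft"}}]}},
]
# The second example uses different field names for "works_for", so it's flagged.
print(check_relation_consistency(data))
```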
---
## 5. Combined Multi-Task Examples
### Entities + Classifications
```jsonl
{"input": "Apple Inc. announced record profits. This is great news for investors.", "output": {"entities": {"organization": ["Apple Inc."]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}}
```
### Entities + JSON Structures
```jsonl
{"input": "Contact John Doe at john@example.com. He works at TechCorp.", "output": {"entities": {"person": ["John Doe"], "organization": ["TechCorp"]}, "json_structures": [{"contact": {"name": "John Doe", "email": "john@example.com", "company": "TechCorp"}}]}}
```
### Entities + Relations
```jsonl
{"input": "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne.", "output": {"entities": {"person": ["Elon Musk"], "organization": ["SpaceX"], "location": ["Hawthorne"], "date": ["2002"]}, "relations": [{"founded": {"head": "Elon Musk", "tail": "SpaceX", "year": "2002"}}, {"located_in": {"head": "SpaceX", "tail": "Hawthorne"}}]}}
```
### Classifications + JSON Structures
```jsonl
{"input": "Premium subscription for $99/month includes unlimited access. Great value!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}], "json_structures": [{"subscription": {"tier": "Premium", "price": "$99/month", "features": ["unlimited access"]}}]}}
```
### Entities + Classifications + JSON Structures
```jsonl
{"input": "Apple CEO Tim Cook unveiled iPhone 15 for $999. Analysts are optimistic.", "output": {"entities": {"person": ["Tim Cook"], "organization": ["Apple"], "product": ["iPhone 15"]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}], "json_structures": [{"product_announcement": {"company": "Apple", "product": "iPhone 15", "price": "$999", "presenter": "Tim Cook"}}]}}
```
### Entities + Relations + Classifications
```jsonl
{"input": "Sarah founded TechStart in 2020. The company is doing exceptionally well.", "output": {"entities": {"person": ["Sarah"], "organization": ["TechStart"], "date": ["2020"]}, "relations": [{"founded": {"head": "Sarah", "tail": "TechStart", "year": "2020"}}], "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}}
```
### All Four Tasks Combined
```jsonl
{"input": "Breaking: Apple announces new iPhone 15 with improved camera. Analysts are optimistic about sales projections.", "output": {"entities": {"company": ["Apple"], "product": ["iPhone 15"]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}, {"task": "category", "labels": ["technology", "business", "sports", "entertainment"], "true_label": ["technology"]}], "json_structures": [{"news_article": {"company": "Apple", "product": "iPhone 15", "feature": "improved camera", "analyst_view": "optimistic"}}], "relations": [{"product_of": {"head": "iPhone 15", "tail": "Apple"}}]}}
```
### Multi-Task with Descriptions
```jsonl
{"input": "Dr. Johnson prescribed medication X for condition Y. Patient shows improvement.", "output": {"entities": {"person": ["Dr. Johnson"], "medication": ["medication X"], "condition": ["condition Y"]}, "entity_descriptions": {"person": "Healthcare provider names", "medication": "Prescribed drugs", "condition": "Medical conditions"}, "classifications": [{"task": "patient_status", "labels": ["improving", "stable", "declining"], "true_label": ["improving"], "label_descriptions": {"improving": "Patient condition getting better", "stable": "No change in condition", "declining": "Patient condition worsening"}}], "json_structures": [{"prescription": {"doctor": "Dr. Johnson", "medication": "medication X", "condition": "condition Y"}}], "json_descriptions": {"prescription": {"doctor": "Prescribing physician", "medication": "Prescribed drug name", "condition": "Diagnosed condition"}}}}
```
### Partial Multi-Task (Some Tasks Empty)
**Note**: While you can include empty dictionaries/lists for some tasks, at least one task must have content.
```jsonl
{"input": "The weather forecast predicts rain tomorrow.", "output": {"entities": {}, "classifications": [{"task": "weather", "labels": ["sunny", "rainy", "cloudy", "snowy"], "true_label": ["rainy"]}], "json_structures": []}}
```
This is valid because it has a classification task. However, if all tasks were empty, it would fail validation.
---
## 6. Format Edge Cases
### Completely Empty Output
**⚠️ Note**: Examples must have at least one task (entities, classifications, structures, or relations). Completely empty outputs are not valid training examples.
```jsonl
{"input": "Random text with no specific information.", "output": {"entities": {}, "classifications": [], "json_structures": [], "relations": []}}
```
This format will fail validation. Each example must contain at least one annotation.
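The "at least one annotation" rule can be sketched as a simple predicate. This is an illustration of the rule only; the real check lives in the library's dataset validation (see `TrainingDataset.validate`). Note that an entities dict with declared-but-empty types (as in the partial-NER examples above) still counts as a task:

```python
def has_annotation(output):
    """An example is trainable only if at least one task is present."""
    return any([
        bool(output.get("entities")),         # any declared entity types count
        bool(output.get("classifications")),
        bool(output.get("json_structures")),
        bool(output.get("relations")),
    ])

empty = {"entities": {}, "classifications": [], "json_structures": [], "relations": []}
ok = {"entities": {}, "classifications": [
    {"task": "weather", "labels": ["sunny", "rainy"], "true_label": ["rainy"]}
]}
print(has_annotation(empty))  # False
print(has_annotation(ok))     # True
```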
### Empty Entities Dictionary
**⚠️ Note**: While an empty entities dictionary is syntactically valid, examples must have at least one task. If you only have empty entities, add at least one other task (classification, structure, or relation).
```jsonl
{"input": "The weather is nice today.", "output": {"entities": {}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative"], "true_label": ["positive"]}]}}
```
### Empty Classifications List
**⚠️ Note**: While an empty classifications list is syntactically valid, examples must have at least one task. If you only have empty classifications, add at least one other task.
```jsonl
{"input": "Some generic text.", "output": {"classifications": [], "entities": {"location": ["text"]}}}
```
### Very Long Label Lists
```jsonl
{"input": "Sample text for many labels.", "output": {"classifications": [{"task": "topic", "labels": ["label1", "label2", "label3", "label4", "label5", "label6", "label7", "label8", "label9", "label10", "label11", "label12", "label13", "label14", "label15", "label16", "label17", "label18", "label19", "label20"], "true_label": ["label5"]}]}}
```
### Very Short Text
```jsonl
{"input": "Yes.", "output": {"classifications": [{"task": "response", "labels": ["yes", "no", "maybe"], "true_label": ["yes"]}]}}
{"input": "OK", "output": {"entities": {}}}
```
### Special Characters in Labels
```jsonl
{"input": "The C++ programming language.", "output": {"entities": {"programming_language": ["C++"]}}}
{"input": "Use the @ symbol for mentions.", "output": {"entities": {"symbol": ["@"]}}}
```
### Special Characters in Values
```jsonl
{"input": "Price is $1,299.99 (including tax).", "output": {"json_structures": [{"pricing": {"amount": "$1,299.99", "note": "(including tax)"}}]}}
```
### Unicode and Non-ASCII Characters
```jsonl
{"input": "Café Münchën serves crème brûlée.", "output": {"entities": {"location": ["Café Münchën"], "food": ["crème brûlée"]}}}
{"input": "东京 Tokyo is the capital.", "output": {"entities": {"location": ["东京", "Tokyo"]}}}
```
### Quotes and Escaping
```jsonl
{"input": "He said \"hello\" to me.", "output": {"entities": {"quote": ["\"hello\""]}}}
```
### Newlines in Text
```jsonl
{"input": "First line.\nSecond line.", "output": {"entities": {"text": ["First line", "Second line"]}}}
```
### Numbers as Strings vs Entity Names
```jsonl
{"input": "Room 123 on floor 4.", "output": {"json_structures": [{"location": {"room": "123", "floor": "4"}}]}}
```
### Boolean-like Values
```jsonl
{"input": "Status is active, notifications enabled.", "output": {"json_structures": [{"settings": {"status": "active", "notifications": "enabled"}}]}}
```
### Empty String Values
```jsonl
{"input": "Name: John, Age: unknown", "output": {"json_structures": [{"person": {"name": "John", "age": ""}}]}}
```
### Multiple Empty Lines in JSONL
Blank lines between records are tolerated and skipped:
```jsonl
{"input": "First example.", "output": {"entities": {"type": ["example"]}}}

{"input": "Second example.", "output": {"entities": {"type": ["example"]}}}
```
---
## Schema Component Reference
### entities
- **Type**: `dict[str, list[str]]`
- **Format**: `{"entity_type": ["mention1", "mention2", ...]}`
- **Example**: `{"person": ["John", "Alice"], "location": ["NYC"]}`
### entity_descriptions
- **Type**: `dict[str, str]`
- **Format**: `{"entity_type": "description text"}`
- **Example**: `{"person": "Names of people", "location": "Geographic places"}`
### classifications
- **Type**: `list[dict]`
- **Required fields**: `task`, `labels`, `true_label`
- **Optional fields**: `multi_label`, `prompt`, `examples`, `label_descriptions`
- **Example**: `[{"task": "sentiment", "labels": ["pos", "neg"], "true_label": ["pos"]}]`
### json_structures
- **Type**: `list[dict]`
- **Single instance**: `[{"parent_name": {"field1": "value1", "field2": ["list", "values"]}}]`
- **Multiple instances (same parent)**: `[{"parent": {...}}, {"parent": {...}}]` - Same parent key, separate dicts
- **Multiple types**: `[{"parent1": {...}}, {"parent2": {...}}]` - Different parent keys
- **Choice format**: `{"field": {"value": "selected", "choices": ["opt1", "opt2"]}}`
- **Example**: `[{"product": {"name": "Item", "price": "$10"}}, {"product": {"name": "Item2", "price": "$20"}}]`
### json_descriptions
- **Type**: `dict[str, dict[str, str]]`
- **Format**: `{"parent": {"field": "description"}}`
- **Example**: `{"product": {"name": "Product name", "price": "Cost in USD"}}`
### relations
- **Type**: `list[dict]`
- **Standard format**: `[{"relation_name": {"head": "entity1", "tail": "entity2"}}]`
- **With custom fields**: `[{"relation_name": {"sender": "A", "recipient": "B", "amount": "$100"}}]`
- **Example**: `[{"works_for": {"head": "John", "tail": "Company"}}, {"founded": {"head": "Alice", "tail": "StartupX"}}]`
- **⚠️ Field constraint**: The first occurrence of each relation type defines the field structure for ALL instances of that type
- **Note**: While "head" and "tail" are common, you can use ANY field names - just keep them consistent per relation type
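The first-occurrence constraint can be checked mechanically. Here is a sketch of such a check (a hypothetical helper, not the library's `validate_relation_consistency()` implementation): the first instance of each relation type fixes the expected field names, and later instances with different fields are flagged.

```python
# Hypothetical consistency check for relation field names.
def check_relation_fields(relations):
    expected = {}  # relation type -> frozenset of field names
    errors = []
    for i, rel in enumerate(relations):
        for rel_type, fields in rel.items():
            field_names = frozenset(fields)
            if rel_type not in expected:
                # First occurrence defines the structure for this type.
                expected[rel_type] = field_names
            elif field_names != expected[rel_type]:
                errors.append(
                    f"instance {i}: '{rel_type}' uses {sorted(field_names)}, "
                    f"expected {sorted(expected[rel_type])}"
                )
    return errors
```

Given `[{"transfer": {"sender": "A", "recipient": "B"}}, {"transfer": {"head": "C", "tail": "D"}}]`, the second instance is reported because it does not match the fields the first instance established.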
---
## Tips for Dataset Creation
1. **Use diverse examples** to improve model generalization
2. **Include edge cases** - but remember each example must have at least one task
3. **Provide descriptions** when possible to improve accuracy
4. **Balance your classes** in classification tasks
5. **Use realistic text** that matches your target domain
6. **Include multiple instances** for JSON structures when applicable
7. **For negative examples**, include at least one task (e.g., empty entities but a classification, or empty classifications but entities)
8. **Mix task types** to train multi-task capabilities
9. **Use consistent formatting** for similar examples
10. **Include special characters** to ensure robust handling
11. **Validate your dataset** using `TrainingDataset.validate(strict=True)` to catch annotation errors early
12. **Check relation consistency** using `validate_relation_consistency()` to ensure all relation types have consistent field structures
## Validation Checklist
Make sure your JSONL file is valid by checking:
- [ ] Each line is valid JSON
- [ ] Required fields (`input`/`output` or `text`/`schema`) are present
- [ ] **At least one task is present** (entities, classifications, structures, or relations)
- [ ] Schema structure matches the expected format
- [ ] Entity mentions exist in the input text (checked in strict validation mode)
- [ ] Classification labels are from the defined label set
- [ ] `true_label` is a list or string (string format is converted to list internally)
- [ ] For multi-label classification, `multi_label` is set to `true` when multiple labels are provided
- [ ] JSON structure fields match between instances of the same parent (flexible - union of fields is used)
- [ ] **Relation field consistency**: All instances of the same relation type use the same field names (determined by first occurrence)
- [ ] No trailing commas in JSON objects
- [ ] Special characters are properly escaped
- [ ] File encoding is UTF-8
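The format-level items on this checklist can be sketched as a small standalone validator (an illustration using only the standard library, not the library's own implementation): parse each line as JSON, require `input`/`output` or `text`/`schema`, and require at least one task.

```python
import json

TASK_KEYS = ("entities", "classifications", "json_structures", "relations")

# Minimal standard-validation sketch over JSONL lines.
def validate_jsonl(lines):
    errors = []
    for n, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # skip blank lines
        try:
            rec = json.loads(line)
        except json.JSONDecodeError as e:
            errors.append(f"line {n}: invalid JSON ({e.msg})")
            continue
        if not isinstance(rec, dict):
            errors.append(f"line {n}: record is not a JSON object")
            continue
        # Required fields: input/output (or the text/schema aliases).
        if not ({"input", "output"} <= rec.keys() or {"text", "schema"} <= rec.keys()):
            errors.append(f"line {n}: missing input/output (or text/schema)")
            continue
        # At least one task must be present and non-empty.
        out = rec.get("output", rec.get("schema", {}))
        if not any(out.get(k) for k in TASK_KEYS):
            errors.append(f"line {n}: no task present")
    return errors
```

Running it over a file is then just `validate_jsonl(open("data.jsonl", encoding="utf-8"))`; an empty result means all format-level checks passed.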
### Validation Modes
The implementation supports two validation modes:
- **Standard validation**: Checks format correctness, required fields, label consistency
- **Strict validation**: Additionally checks that entity mentions and relation values exist in the input text (case-insensitive substring matching)
Use strict validation during dataset creation to catch annotation errors early.
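The strict-mode check described above amounts to case-insensitive substring matching. A sketch of that idea (an assumption based on the description, not the library's actual code):

```python
# Return (entity_type, mention) pairs whose mention is not found in the text,
# using case-insensitive substring matching as strict mode describes.
def check_mentions_in_text(text, entities):
    lowered = text.lower()
    missing = []
    for etype, mentions in entities.items():
        for m in mentions:
            if m.lower() not in lowered:
                missing.append((etype, m))
    return missing
```

For example, checking `{"person": ["John"], "org": ["Globex"]}` against `"John works at Acme."` reports `("org", "Globex")` as a likely annotation error, while the case-mismatched `"john"` still passes.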
