gliner2

commit 9b32c8dd29 (parent afa907065b)

packages/GLiNER2/.gitignore (vendored, new file, +97 lines)
@@ -0,0 +1,97 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Ruff
.ruff_cache/

# IDEs
.idea/
.vscode/
*.swp
*.swo
*~

# OS files
.DS_Store
Thumbs.db

# Model files (typically large)
*.pt
*.pth
*.bin
*.onnx
*.safetensors

# Logs
*.log

test_api_client.py
packages/GLiNER2/LICENSE (new file, +201 lines)
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
packages/GLiNER2/README.md (new file, +1031 lines)
(File diff suppressed because it is too large.)
packages/GLiNER2/RELEASE.md (new file, +79 lines)
@@ -0,0 +1,79 @@
# PyPI Release Guide for GLiNER2

## Prerequisites

- [ ] Python 3.8+ installed
- [ ] PyPI account with API token configured
- [ ] Write access to the repository

## Release Steps

### 1. Update Version

Update the version in `gliner2/__init__.py`:

```python
__version__ = "1.0.1"  # New version
```
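The release tag used later in this guide must match this version string exactly. A small helper like the following can read the version back out so the tag is never typed by hand (a sketch; `read_version` and the demo file are illustrative, not part of the package):

```python
import re
from pathlib import Path

def read_version(init_path) -> str:
    """Extract the __version__ string from a module file."""
    text = Path(init_path).read_text()
    match = re.search(r'__version__\s*=\s*"([^"]+)"', text)
    if match is None:
        raise ValueError(f"no __version__ found in {init_path}")
    return match.group(1)

# Demo against a throwaway file standing in for gliner2/__init__.py
demo = Path("_version_demo.py")
demo.write_text('__version__ = "1.0.1"  # New version\n')
tag = f"v{read_version(demo)}"
demo.unlink()
print(tag)  # → v1.0.1
```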

### 2. Build Package

```bash
# Install build tools
pip install build twine

# Clean previous builds
rm -rf dist/ build/ *.egg-info/

# Build package
python -m build
```

### 3. Test Build (Optional)

```bash
# Test on TestPyPI first
twine upload --repository testpypi dist/*

# Install and test (pull dependencies from production PyPI,
# since TestPyPI usually does not host them)
pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ gliner2
```

### 4. Upload to PyPI

```bash
# Upload to production PyPI
twine upload dist/*
```

### 5. Create GitHub Release

1. Go to GitHub repository → Releases
2. Click "Create a new release"
3. Tag: `v1.0.1` (matching the version)
4. Title: `GLiNER2 v1.0.1`
5. Description: summary of changes
6. Attach the built wheels from the `dist/` folder

### 6. Verify Release

```bash
# Install from PyPI
pip install gliner2==1.0.1

# Test basic functionality
python -c "from gliner2 import GLiNER2; print('✓ Import successful')"
```

## Troubleshooting

- **Authentication error**: Configure a PyPI token in `~/.pypirc` or use `--username __token__`
- **File exists error**: The version already exists on PyPI; increment the version number
- **Build fails**: Check `pyproject.toml` dependencies and Python version compatibility
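For the authentication error above, a token-based `~/.pypirc` covering both PyPI and TestPyPI looks roughly like this (a sketch; the `pypi-<...>` values are placeholders for your own API tokens):

```ini
[distutils]
index-servers =
    pypi
    testpypi

[pypi]
username = __token__
password = pypi-<your-pypi-token>

[testpypi]
repository = https://test.pypi.org/legacy/
username = __token__
password = pypi-<your-testpypi-token>
```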

## Checklist

- [ ] Version updated in `__init__.py`
- [ ] Package builds without errors
- [ ] Uploaded to PyPI successfully
- [ ] GitHub release created
- [ ] Installation verified
packages/GLiNER2/benchmark_statistical.py (new file, +401 lines)
@@ -0,0 +1,401 @@
"""
Statistical benchmark with confidence intervals and p-values.

Micro-benchmarks: interleaved old/new in same process → paired t-test.
End-to-end: saves raw timings to JSON for cross-process Welch's t-test.

Usage:
    # Baseline
    git stash
    python benchmark_statistical.py --tag baseline --n 300
    git stash pop

    # Optimized
    python benchmark_statistical.py --tag optimized --n 300

    # Compare
    python benchmark_statistical.py --compare baseline optimized
"""

import argparse
import json
import math
import random
import time
import statistics
import sys
from collections import OrderedDict

import torch
from scipy import stats as sp_stats


# ─── Helpers ──────────────────────────────────────────────────────

def sync():
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def ci95(data):
    """95% CI half-width using t-distribution."""
    n = len(data)
    if n < 2:
        return 0.0
    se = statistics.stdev(data) / math.sqrt(n)
    t_crit = sp_stats.t.ppf(0.975, df=n - 1)
    return t_crit * se


def collect(fn, n_warmup, n_iter):
    """Run fn with warmup, return list of times in ms."""
    for _ in range(n_warmup):
        fn()
    sync()
    times = []
    for _ in range(n_iter):
        sync()
        t0 = time.perf_counter()
        fn()
        sync()
        times.append((time.perf_counter() - t0) * 1000)
    return times


def paired_test(old_times, new_times):
    """Paired t-test on matched samples. Returns (t_stat, p_value, mean_diff, ci95_diff)."""
    diffs = [o - n for o, n in zip(old_times, new_times)]
    n = len(diffs)
    mean_d = statistics.mean(diffs)
    se_d = statistics.stdev(diffs) / math.sqrt(n)
    t_stat = mean_d / se_d if se_d > 0 else 0
    p_val = 2 * sp_stats.t.sf(abs(t_stat), df=n - 1)
    hw = ci95(diffs)
    return t_stat, p_val, mean_d, hw


def welch_test(a, b):
    """Welch's t-test (unequal variance). Returns (t_stat, p_value)."""
    t_stat, p_val = sp_stats.ttest_ind(a, b, equal_var=False)
    return t_stat, p_val


def fmt_p(p):
    if p < 0.001:
        return f"{p:.2e}"
    return f"{p:.4f}"


# ─── End-to-end benchmark ────────────────────────────────────────

def run_e2e(n_iter, n_warmup):
    """Run end-to-end scenarios, return dict of {name: [times]}."""
    from gliner2 import GLiNER2

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
    model = model.to(device)
    model.eval()

    text1 = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12, 2023."
    ents = ["company", "person", "product", "location", "date"]
    texts8 = [
        "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino.",
        "Google's Sundar Pichai spoke at the conference in Mountain View.",
        "Microsoft released Windows 11 in Redmond last year.",
        "Amazon founder Jeff Bezos invested in Blue Origin in Seattle.",
        "Tesla CEO Elon Musk unveiled the Cybertruck at the Fremont factory.",
        "Meta's Mark Zuckerberg presented Quest 3 in Menlo Park.",
        "NVIDIA's Jensen Huang showcased the H100 GPU at GTC in San Jose.",
        "OpenAI CEO Sam Altman launched GPT-4 in San Francisco.",
    ]
    long_text = (
        "Apple Inc., headquartered in Cupertino, California, is a multinational technology company "
        "founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in April 1976. The company designs, "
        "develops, and sells consumer electronics, computer software, and online services. Tim Cook "
        "has served as CEO since August 2011. Apple's main products include the iPhone, iPad, Mac, "
        "Apple Watch, and AirPods. The company also operates services including the App Store, "
        "Apple Music, iCloud, and Apple TV Plus. In 2023, Apple reported annual revenue of $383 "
        "billion, making it the world's largest technology company by revenue. The company employs "
        "over 160,000 people worldwide."
    )
    ents6 = ["company", "person", "product", "location", "date", "monetary_value"]
    text_struct = "John Smith, aged 35, is a software engineer at Google in Mountain View."
    schema_struct = model.create_schema()
    schema_struct.structure("person").field("name").field("age").field("job_title").field("company").field("location")
    text_rel = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12."
    rels = ["CEO_of", "located_in", "announced_on"]

    results = OrderedDict()
    scenarios = [
        ("single_entity", lambda: model.extract_entities(text1, ents)),
        ("single_structure", lambda: model.extract(text_struct, schema_struct)),
        ("single_relation", lambda: model.extract_relations(text_rel, rels)),
        ("batch8_entity", lambda: model.batch_extract_entities(texts8, ents, batch_size=8)),
        ("long_text_entity", lambda: model.extract_entities(long_text, ents6)),
    ]

    for name, fn in scenarios:
        print(f"  Running {name} (n={n_iter})...", end=" ", flush=True)
        times = collect(fn, n_warmup, n_iter)
        results[name] = times
        m, hw = statistics.mean(times), ci95(times)
        print(f"{m:.2f} ± {hw:.2f} ms")

    return results


# ─── Micro-benchmarks (interleaved old/new) ──────────────────────

def run_micro(n_iter, n_warmup):
    """Run micro-benchmarks with interleaved old/new for paired comparison."""
    import copy
    from gliner2 import GLiNER2
    from gliner2.training.trainer import ExtractorCollator
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
    model = model.to(device)
    model.eval()
    tokenizer = model.processor.tokenizer

    results = OrderedDict()

    # --- OPT-1: Token ID lookup ---
    special_set_str = {"[P]", "[C]", "[E]", "[R]", "[L]"}
    special_ids = frozenset(tokenizer.convert_tokens_to_ids(t) for t in special_set_str)
    dummy_ids = list(range(200))

    def opt1_old():
        for tid in dummy_ids:
            tok = tokenizer.convert_ids_to_tokens(tid)
            _ = tok in special_set_str

    def opt1_new():
        for tid in dummy_ids:
            _ = tid in special_ids

    print("  OPT-1 Token ID lookup...", end=" ", flush=True)
    old_t, new_t = _interleaved(opt1_old, opt1_new, n_warmup, n_iter)
    results["OPT-1 Token ID lookup"] = {"old": old_t, "new": new_t}
    _print_paired(old_t, new_t)

    # --- OPT-3: Avoid retokenization ---
    test_text = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12."
    dummy_map = list(range(15))

    def opt3_old():
        return len(model.processor._tokenize_text(test_text))

    def opt3_new():
        return len(dummy_map)

    print("  OPT-3 Avoid retokenization...", end=" ", flush=True)
    old_t, new_t = _interleaved(opt3_old, opt3_new, n_warmup, n_iter)
    results["OPT-3 Avoid retokenization"] = {"old": old_t, "new": new_t}
    _print_paired(old_t, new_t)

    # --- OPT-4: Deepcopy ---
    schema_dict = {
        "json_structures": [{"person": {"name": "", "age": "", "job": ""}}],
        "entities": {"company": "", "location": ""},
        "relations": [], "classifications": [],
    }
    record = {"text": "Apple CEO Tim Cook announced iPhone 15." * 3, "schema": schema_dict}

    def opt4_old():
        return copy.deepcopy(record)

    def opt4_new():
        return {"text": record["text"], "schema": copy.deepcopy(record["schema"])}

    print("  OPT-4 Deepcopy...", end=" ", flush=True)
    old_t, new_t = _interleaved(opt4_old, opt4_new, n_warmup, n_iter)
    results["OPT-4 Deepcopy"] = {"old": old_t, "new": new_t}
    _print_paired(old_t, new_t)

    # --- OPT-6: Token cache ---
    special_tokens = ["[SEP_STRUCT]", "[SEP_TEXT]", "[P]", "[C]", "[E]", "[R]", "[L]",
                      "[EXAMPLE]", "[OUTPUT]", "[DESCRIPTION]", "(", ")", ",", "|"]
    cache = {tok: tokenizer.tokenize(tok) for tok in special_tokens}
    test_tokens = special_tokens * 10

    def opt6_old():
        for tok in test_tokens:
            tokenizer.tokenize(tok)

    def opt6_new():
        for tok in test_tokens:
            if tok in cache:
                _ = cache[tok]
            else:
                tokenizer.tokenize(tok)

    print("  OPT-6 Token cache...", end=" ", flush=True)
    old_t, new_t = _interleaved(opt6_old, opt6_new, n_warmup, n_iter)
    results["OPT-6 Token cache"] = {"old": old_t, "new": new_t}
    _print_paired(old_t, new_t)

    # --- OPT-12: Skip DataLoader ---
    collator = ExtractorCollator(model.processor, is_training=False)
    text_norm = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12, 2023."
    schema_e = model.create_schema().entities(["company", "person", "product", "location", "date"])
    sd = schema_e.build()
    for c in sd.get("classifications", []):
        c.setdefault("true_label", ["N/A"])
    small_dataset = [(text_norm, sd)]

    def opt12_old():
        loader = DataLoader(small_dataset, batch_size=8, shuffle=False,
                            num_workers=0, collate_fn=collator)
        return list(loader)

    def opt12_new():
        return [collator(small_dataset)]

    print("  OPT-12 Skip DataLoader...", end=" ", flush=True)
    old_t, new_t = _interleaved(opt12_old, opt12_new, n_warmup, n_iter)
    results["OPT-12 Skip DataLoader"] = {"old": old_t, "new": new_t}
    _print_paired(old_t, new_t)

    return results


def _interleaved(old_fn, new_fn, n_warmup, n_iter):
    """Run old/new interleaved to eliminate ordering effects. Returns paired lists."""
    # Warmup both
    for _ in range(n_warmup):
        old_fn()
        new_fn()
    sync()

    old_times = []
    new_times = []
    for _ in range(n_iter):
        # Randomize order each iteration to eliminate systematic bias
        if random.random() < 0.5:
            sync(); t0 = time.perf_counter(); old_fn(); sync()
            old_times.append((time.perf_counter() - t0) * 1000)
            sync(); t0 = time.perf_counter(); new_fn(); sync()
            new_times.append((time.perf_counter() - t0) * 1000)
        else:
            sync(); t0 = time.perf_counter(); new_fn(); sync()
            new_times.append((time.perf_counter() - t0) * 1000)
            sync(); t0 = time.perf_counter(); old_fn(); sync()
            old_times.append((time.perf_counter() - t0) * 1000)

    return old_times, new_times


def _print_paired(old_t, new_t):
    m_old, m_new = statistics.mean(old_t), statistics.mean(new_t)
    t_stat, p_val, mean_diff, hw = paired_test(old_t, new_t)
    speedup = m_old / m_new if m_new > 0 else float('inf')
    print(f"{m_old:.4f} -> {m_new:.4f} ms ({speedup:.1f}x) "
          f"diff={mean_diff:.4f}±{hw:.4f}ms p={fmt_p(p_val)}")


# ─── Compare mode ────────────────────────────────────────────────

def compare(baseline_path, optimized_path):
    """Compare two end-to-end result files with Welch's t-test."""
    with open(baseline_path) as f:
        baseline = json.load(f)
    with open(optimized_path) as f:
        optimized = json.load(f)

    print(f"\nBaseline: {baseline_path} (device={baseline['device']}, n={baseline.get('n', '?')})")
    print(f"Optimized: {optimized_path} (device={optimized['device']}, n={optimized.get('n', '?')})")

    print(f"\n{'Scenario':<25} {'Baseline':>18} {'Optimized':>18} {'Diff':>14} {'Speedup':>8} {'p-value':>10}")
    print("=" * 100)

    for name in baseline["e2e"]:
        b = baseline["e2e"][name]
        o = optimized["e2e"][name]

        m_b, ci_b = statistics.mean(b), ci95(b)
        m_o, ci_o = statistics.mean(o), ci95(o)
        diff = m_b - m_o
        diff_ci = math.sqrt(ci_b**2 + ci_o**2)  # approximate CI of difference
        speedup = m_b / m_o if m_o > 0 else float('inf')
        t_stat, p_val = welch_test(b, o)

        sig = "*" if p_val < 0.05 else " "
        if p_val < 0.01:
            sig = "**"
        if p_val < 0.001:
            sig = "***"

        print(f"{name:<25} {m_b:>7.2f}±{ci_b:>5.2f}ms {m_o:>7.2f}±{ci_o:>5.2f}ms "
              f"{diff:>+6.2f}±{diff_ci:>4.2f}ms {speedup:>7.3f}x {fmt_p(p_val):>9}{sig}")

    # Micro-benchmarks (if present in optimized)
    if "micro" in optimized:
        print(f"\n{'Component':<30} {'Old':>16} {'New':>16} {'Diff (paired)':>18} {'Speedup':>8} {'p-value':>10}")
        print("=" * 105)

        for name, data in optimized["micro"].items():
            old_t = data["old"]
            new_t = data["new"]
            m_old, ci_old = statistics.mean(old_t), ci95(old_t)
            m_new, ci_new = statistics.mean(new_t), ci95(new_t)
            t_stat, p_val, mean_diff, hw = paired_test(old_t, new_t)
            speedup = m_old / m_new if m_new > 0 else float('inf')

            sig = "*" if p_val < 0.05 else " "
            if p_val < 0.01: sig = "**"
            if p_val < 0.001: sig = "***"

            print(f"{name:<30} {m_old:>6.4f}±{ci_old:>6.4f}ms {m_new:>6.4f}±{ci_new:>6.4f}ms "
                  f"{mean_diff:>+7.4f}±{hw:>6.4f}ms {speedup:>7.1f}x {fmt_p(p_val):>9}{sig}")


# ─── Main ────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tag", help="Tag for this run (baseline or optimized)")
    parser.add_argument("--n", type=int, default=300, help="Iterations per scenario")
    parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
    parser.add_argument("--compare", nargs=2, metavar=("BASELINE", "OPTIMIZED"),
                        help="Compare two result files")
    args = parser.parse_args()

    if args.compare:
        compare(
            f"bench_stats_{args.compare[0]}.json",
            f"bench_stats_{args.compare[1]}.json"
        )
        return

    if not args.tag:
        parser.error("--tag is required (or use --compare)")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    print(f"Iterations: {args.n}, Warmup: {args.warmup}\n")

    output = {"tag": args.tag, "device": device, "n": args.n}

    # End-to-end
    print("END-TO-END BENCHMARKS")
    print("-" * 60)
    e2e = run_e2e(args.n, args.warmup)
    output["e2e"] = e2e

    # Micro-benchmarks (only meaningful for optimized run since we inline both versions)
    print("\nCOMPONENT MICRO-BENCHMARKS (interleaved old/new)")
    print("-" * 60)
    micro = run_micro(args.n, args.warmup)
    output["micro"] = {k: v for k, v in micro.items()}

    out_path = f"bench_stats_{args.tag}.json"
    with open(out_path, "w") as f:
        json.dump(output, f)
    print(f"\nRaw timings saved to {out_path}")


if __name__ == "__main__":
    main()
packages/GLiNER2/gliner2/__init__.py (new file, +23 lines)
@@ -0,0 +1,23 @@
__version__ = "1.2.4"

from .inference.engine import GLiNER2, RegexValidator
from .model import Extractor, ExtractorConfig
from .api_client import (
    GLiNER2API,
    GLiNER2APIError,
    AuthenticationError,
    ValidationError,
    ServerError,
)
from .training.lora import (
    LoRAConfig,
    LoRAAdapterConfig,
    LoRALayer,
    load_lora_adapter,
    save_lora_adapter,
    unload_lora_adapter,
    has_lora_adapter,
    apply_lora_to_model,
    merge_lora_weights,
    unmerge_lora_weights,
)
989
packages/GLiNER2/gliner2/api_client.py
Normal file
@ -0,0 +1,989 @@
"""
GLiNER2 API Client

This module provides an API-based wrapper for GLiNER2 that mirrors the local
model interface. It allows seamless switching between local and API-based
inference.

Usage:
    >>> from gliner2 import GLiNER2
    >>>
    >>> # Load from API (uses environment variable for API key)
    >>> extractor = GLiNER2.from_api()
    >>>
    >>> # Use exactly like the local model
    >>> results = extractor.extract_entities(
    ...     "Apple released iPhone 15 in September 2023.",
    ...     ["company", "product", "date"]
    ... )
"""

from __future__ import annotations

import os
import logging
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Union, Literal
from urllib.parse import urljoin

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

logger = logging.getLogger(__name__)


class GLiNER2APIError(Exception):
    """Base exception for GLiNER2 API errors."""

    def __init__(self, message: str, status_code: Optional[int] = None, response_data: Optional[Dict] = None):
        super().__init__(message)
        self.status_code = status_code
        self.response_data = response_data


class AuthenticationError(GLiNER2APIError):
    """Raised when the API key is invalid or expired."""


class ValidationError(GLiNER2APIError):
    """Raised when request data is invalid."""


class ServerError(GLiNER2APIError):
    """Raised when the server encounters an error."""


class StructureBuilderAPI:
    """
    Builder for structured data schemas for API-based extraction.

    This mirrors the interface of StructureBuilder from the local model.
    """

    def __init__(self, schema: 'SchemaAPI', parent: str):
        self.schema = schema
        self.parent = parent
        self.fields = OrderedDict()
        self.field_order = []
        self._finished = False

    def field(
        self,
        name: str,
        dtype: Literal["str", "list"] = "list",
        choices: Optional[List[str]] = None,
        description: Optional[str] = None,
        threshold: Optional[float] = None,
        validators: Optional[List] = None
    ) -> 'StructureBuilderAPI':
        """Add a field to the structured data."""
        # Warn if validators are used (not supported in API mode)
        if validators:
            warnings.warn(
                f"Field '{name}': RegexValidator is not supported in API mode. "
                "Validators will be ignored. Use the local model for regex-based filtering.",
                UserWarning,
                stacklevel=2
            )

        self.fields[name] = {
            "dtype": dtype,
            "choices": choices,
            "description": description,
            "threshold": threshold
        }
        self.field_order.append(name)
        return self

    def _auto_finish(self):
        """Automatically finish this structure when needed."""
        if not self._finished:
            # Convert fields to the API format.
            # Use the dict format if any field has a threshold or choices (advanced
            # features); otherwise use the simple string format for backwards
            # compatibility.
            field_specs = []
            for name in self.field_order:
                config = self.fields[name]

                # Check whether advanced features are used
                has_threshold = config.get('threshold') is not None
                has_choices = config.get('choices') is not None

                if has_threshold or has_choices:
                    # Use the dict format for advanced features
                    field_dict = {"name": name, "dtype": config['dtype']}
                    if config.get('description'):
                        field_dict["description"] = config['description']
                    if has_threshold:
                        field_dict["threshold"] = config['threshold']
                    if has_choices:
                        field_dict["choices"] = config['choices']
                    field_specs.append(field_dict)
                else:
                    # Use the simple string format: "name::type::description"
                    spec = f"{name}::{config['dtype']}"
                    if config.get('description'):
                        spec += f"::{config['description']}"
                    field_specs.append(spec)

            self.schema._structures[self.parent] = field_specs
            self._finished = True

    def __getattr__(self, name):
        """Auto-finish when any schema method is called."""
        if hasattr(self.schema, name):
            self._auto_finish()
            return getattr(self.schema, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
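`_auto_finish` above emits either a dict spec or a `name::dtype::description` string per field. A minimal standalone sketch of that encoding rule (the `encode_field` helper is mine, not part of the package):

```python
def encode_field(name, dtype="list", description=None, threshold=None, choices=None):
    """Encode one field the way StructureBuilderAPI._auto_finish does."""
    if threshold is not None or choices is not None:
        # Advanced features force the dict format
        spec = {"name": name, "dtype": dtype}
        if description:
            spec["description"] = description
        if threshold is not None:
            spec["threshold"] = threshold
        if choices is not None:
            spec["choices"] = choices
        return spec
    # Simple "name::dtype[::description]" string otherwise
    spec = f"{name}::{dtype}"
    if description:
        spec += f"::{description}"
    return spec

print(encode_field("price", "str", description="listed price"))
# price::str::listed price
print(encode_field("color", choices=["red", "blue"]))
# {'name': 'color', 'dtype': 'list', 'choices': ['red', 'blue']}
```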
class SchemaAPI:
    """Schema builder for API-based extraction tasks."""

    def __init__(self):
        self._entities = None
        self._entity_dtype = "list"
        self._entity_threshold = None
        self._classifications = {}
        self._structures = {}
        self._relations = None
        self._relation_threshold = None
        self._active_structure_builder = None

    def entities(
        self,
        entity_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
        dtype: Literal["str", "list"] = "list",
        threshold: Optional[float] = None
    ) -> 'SchemaAPI':
        """Add an entity extraction task."""
        if self._active_structure_builder:
            self._active_structure_builder._auto_finish()
            self._active_structure_builder = None

        # Normalize to a list or dict
        if isinstance(entity_types, str):
            self._entities = [entity_types]
        elif isinstance(entity_types, (list, dict)):
            self._entities = entity_types

        self._entity_dtype = dtype
        self._entity_threshold = threshold
        return self

    def classification(
        self,
        task: str,
        labels: Union[List[str], Dict[str, str]],
        multi_label: bool = False,
        cls_threshold: float = 0.5,
        **kwargs
    ) -> 'SchemaAPI':
        """Add a text classification task."""
        if self._active_structure_builder:
            self._active_structure_builder._auto_finish()
            self._active_structure_builder = None

        # Parse labels
        if isinstance(labels, dict):
            label_names = list(labels.keys())
        else:
            label_names = labels

        self._classifications[task] = {
            "labels": label_names,
            "multi_label": multi_label,
            "cls_threshold": cls_threshold
        }
        return self

    def structure(self, name: str) -> StructureBuilderAPI:
        """Start building a structured data schema."""
        if self._active_structure_builder:
            self._active_structure_builder._auto_finish()

        self._active_structure_builder = StructureBuilderAPI(self, name)
        return self._active_structure_builder

    def relations(
        self,
        relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
        threshold: Optional[float] = None
    ) -> 'SchemaAPI':
        """
        Add a relation extraction task.

        Args:
            relation_types: Relation types to extract. Can be:
                - str: Single relation type
                - List[str]: Multiple relation types
                - Dict[str, str]: Relation types with descriptions
                - Dict[str, Dict]: Relation types with full configuration
            threshold: Default confidence threshold for relations.

        Returns:
            Self for method chaining.
        """
        if self._active_structure_builder:
            self._active_structure_builder._auto_finish()
            self._active_structure_builder = None

        # Normalize to a list or dict
        if isinstance(relation_types, str):
            self._relations = [relation_types]
        elif isinstance(relation_types, (list, dict)):
            self._relations = relation_types

        self._relation_threshold = threshold
        return self

    def build(self) -> Dict[str, Any]:
        """Build the schema for an API request."""
        if self._active_structure_builder:
            self._active_structure_builder._auto_finish()
            self._active_structure_builder = None

        schema = {}

        if self._entities is not None:
            schema["entities"] = self._entities
            schema["entity_dtype"] = self._entity_dtype
            if self._entity_threshold is not None:
                schema["entity_threshold"] = self._entity_threshold

        if self._classifications:
            schema["classifications"] = self._classifications

        if self._structures:
            schema["structures"] = self._structures

        if self._relations is not None:
            schema["relations"] = self._relations
            if self._relation_threshold is not None:
                schema["relation_threshold"] = self._relation_threshold

        return schema


class GLiNER2API:
    """
    API-based GLiNER2 client that mirrors the local model interface.

    This class provides the same methods as GLiNER2 but makes HTTP requests
    to the API endpoint instead of running local inference.

    Attributes:
        api_key: API authentication key
        base_url: API base URL
        timeout: Request timeout in seconds
        max_retries: Maximum number of retries for failed requests
    """

    DEFAULT_BASE_URL = "https://api.fastino.ai"

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base_url: Optional[str] = None,
        timeout: float = 30.0,
        max_retries: int = 3,
    ):
        """
        Initialize the GLiNER2 API client.

        Args:
            api_key: API authentication key. If not provided, reads from the
                PIONEER_API_KEY environment variable.
            api_base_url: Override the default API base URL.
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retries for failed requests.

        Raises:
            ValueError: If no API key is provided and PIONEER_API_KEY is not set.
        """
        # Read the API key from the environment if not provided
        if api_key is None:
            api_key = os.environ.get("PIONEER_API_KEY")
            if api_key is None:
                raise ValueError(
                    "API key must be provided either as an argument or via the "
                    "PIONEER_API_KEY environment variable"
                )

        self.api_key = api_key
        self.base_url = api_base_url or os.environ.get(
            "GLINER2_API_BASE_URL", self.DEFAULT_BASE_URL
        )
        self.timeout = timeout
        self.max_retries = max_retries

        # Set up an HTTP session with retry logic
        self.session = requests.Session()
        self.session.headers.update({
            "X-API-Key": api_key,
            "Content-Type": "application/json",
        })

        # Configure the retry strategy
        retry_strategy = Retry(
            total=max_retries,
            backoff_factor=1,  # 1s, 2s, 4s backoff
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["POST"],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

        logger.debug(f"Initialized GLiNER2API for {self.base_url}")
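With `backoff_factor=1` in the retry strategy above, urllib3 sleeps exponentially longer between attempts. A standalone sketch of that schedule, assuming urllib3 2.x's documented formula (`backoff_factor * 2**retry_number`, capped at its 120 s default); the helper is mine, not part of urllib3:

```python
def backoff_schedule(backoff_factor, retries, cap=120.0):
    """Sleep times between successive retry attempts, urllib3 2.x style."""
    return [min(cap, backoff_factor * (2 ** n)) for n in range(retries)]

print(backoff_schedule(1, 3))  # [1, 2, 4]
```

This matches the "1s, 2s, 4s" comment in the constructor; note that older urllib3 1.x skipped the backoff before the first retry.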
    def _make_request(
        self,
        task: str,
        text: Union[str, List[str]],
        schema: Union[List[str], Dict],
        threshold: float = 0.5,
        include_confidence: bool = False,
        include_spans: bool = False,
        format_results: bool = True,
    ) -> Dict[str, Any]:
        """
        Make an HTTP request to the GLiNER2 API.

        Args:
            task: Task type (extract_entities, classify_text, extract_json, schema)
            text: Text to process (string or list for batch)
            schema: Schema for extraction
            threshold: Confidence threshold
            include_confidence: Whether to include confidence scores in results
            include_spans: Whether to include character-level start/end positions
            format_results: Whether to format results (False for raw extraction data)

        Returns:
            API response result

        Raises:
            GLiNER2APIError: If the request fails
        """
        # Ensure base_url ends with / so urljoin keeps the full path
        base = self.base_url.rstrip('/') + '/'
        url = urljoin(base, "gliner-2")

        payload = {
            "task": task,
            "text": text,
            "schema": schema,
            "threshold": threshold,
            "include_confidence": include_confidence,
            "include_spans": include_spans,
            "format_results": format_results,
        }

        logger.debug(f"Making POST request to {url}")

        try:
            response = self.session.post(
                url,
                json=payload,
                timeout=self.timeout,
            )

            logger.debug(f"Response status: {response.status_code}")

            # Handle the different error codes
            if response.status_code == 401:
                error_data = response.json() if response.content else None
                error_msg = (
                    error_data.get("detail", "Invalid or expired API key")
                    if error_data else "Invalid or expired API key"
                )
                raise AuthenticationError(
                    error_msg,
                    status_code=response.status_code,
                    response_data=error_data,
                )

            elif response.status_code in (400, 422):
                error_data = response.json() if response.content else None
                error_msg = (
                    error_data.get("detail", "Request validation failed")
                    if error_data else "Request validation failed"
                )
                raise ValidationError(
                    error_msg,
                    status_code=response.status_code,
                    response_data=error_data,
                )

            elif response.status_code >= 500:
                error_data = response.json() if response.content else None
                error_msg = (
                    error_data.get("detail", "Server error occurred")
                    if error_data else "Server error occurred"
                )
                raise ServerError(
                    error_msg,
                    status_code=response.status_code,
                    response_data=error_data,
                )

            elif not response.ok:
                error_data = response.json() if response.content else None
                error_msg = (
                    error_data.get("detail", f"Request failed with status {response.status_code}")
                    if error_data else f"Request failed with status {response.status_code}"
                )
                raise GLiNER2APIError(
                    error_msg,
                    status_code=response.status_code,
                    response_data=error_data,
                )

            data = response.json()
            return data.get("result", data)

        except requests.exceptions.Timeout:
            raise GLiNER2APIError(f"Request timed out after {self.timeout}s")
        except requests.exceptions.ConnectionError as e:
            raise GLiNER2APIError(f"Connection error: {str(e)}")
        except requests.exceptions.RequestException as e:
            raise GLiNER2APIError(f"Request failed: {str(e)}")
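The `rstrip('/') + '/'` normalization in `_make_request` guards against a stdlib `urljoin` pitfall: without a trailing slash, the last path segment of the base URL is treated as a file and replaced. A quick demonstration (the `/v1` path is illustrative, not the real API layout):

```python
from urllib.parse import urljoin

# No trailing slash: the final segment is dropped
print(urljoin("https://api.example.com/v1", "gliner-2"))
# https://api.example.com/gliner-2

# Trailing slash: the segment is treated as a directory and kept
print(urljoin("https://api.example.com/v1/", "gliner-2"))
# https://api.example.com/v1/gliner-2
```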
def create_schema(self) -> SchemaAPI:
|
||||
"""Create a new schema for defining extraction tasks."""
|
||||
return SchemaAPI()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Entity Extraction Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def extract_entities(
|
||||
self,
|
||||
text: str,
|
||||
entity_types: Union[List[str], Dict[str, Union[str, Dict]]],
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract entities from text.
|
||||
|
||||
Args:
|
||||
text: Input text to extract entities from.
|
||||
entity_types: List of entity types or dict with descriptions.
|
||||
threshold: Minimum confidence threshold.
|
||||
format_results: Whether to format results. If False, returns raw extraction data.
|
||||
include_confidence: Whether to include confidence scores in results.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
Dictionary with "entities" key containing extracted entities.
|
||||
If include_confidence=True, entity values include confidence scores.
|
||||
If include_spans=True, entity values include start/end positions.
|
||||
If format_results=False, returns raw extraction data with positions.
|
||||
"""
|
||||
# Normalize entity types to list
|
||||
if isinstance(entity_types, dict):
|
||||
entities = list(entity_types.keys())
|
||||
else:
|
||||
entities = entity_types
|
||||
|
||||
result = self._make_request(
|
||||
task="extract_entities",
|
||||
text=text,
|
||||
schema=entities,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
|
||||
# Wrap result in expected format if needed (only for formatted results)
|
||||
if format_results and isinstance(result, dict) and "entities" not in result:
|
||||
return {"entities": result}
|
||||
return result
|
||||
|
||||
def batch_extract_entities(
|
||||
self,
|
||||
texts: List[str],
|
||||
entity_types: Union[List[str], Dict[str, Union[str, Dict]]],
|
||||
batch_size: int = 8,
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Batch extract entities from multiple texts.
|
||||
|
||||
Args:
|
||||
texts: List of input texts.
|
||||
entity_types: List of entity types or dict with descriptions.
|
||||
batch_size: Batch size (used by API for optimization).
|
||||
threshold: Minimum confidence threshold.
|
||||
format_results: Whether to format results. If False, returns raw extraction data.
|
||||
include_confidence: Whether to include confidence scores.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with "entities" key.
|
||||
If include_confidence=True, entity values include confidence scores.
|
||||
If include_spans=True, entity values include start/end positions.
|
||||
If format_results=False, returns raw extraction data with positions.
|
||||
"""
|
||||
# Normalize entity types to list
|
||||
if isinstance(entity_types, dict):
|
||||
entities = list(entity_types.keys())
|
||||
else:
|
||||
entities = entity_types
|
||||
|
||||
result = self._make_request(
|
||||
task="extract_entities",
|
||||
text=texts,
|
||||
schema=entities,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
|
||||
# Ensure result is a list
|
||||
if isinstance(result, dict):
|
||||
return [result]
|
||||
return result
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Text Classification Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def classify_text(
|
||||
self,
|
||||
text: str,
|
||||
tasks: Dict[str, Union[List[str], Dict[str, Any]]],
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Classify text into categories.
|
||||
|
||||
Args:
|
||||
text: Text to classify.
|
||||
tasks: Classification tasks where keys are task names.
|
||||
threshold: Confidence threshold.
|
||||
format_results: Whether to format results. If False, returns raw extraction data.
|
||||
include_confidence: Whether to include confidence scores.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
Classification results keyed by task name.
|
||||
If include_confidence=True, results include confidence scores.
|
||||
If format_results=False, returns raw extraction data.
|
||||
"""
|
||||
# Convert tasks to API format
|
||||
# For classify_text task, schema should be {"categories": [...]}
|
||||
# But for multi-task, we need to use the schema task
|
||||
if len(tasks) == 1:
|
||||
# Single task - use classify_text endpoint
|
||||
task_name = list(tasks.keys())[0]
|
||||
task_config = tasks[task_name]
|
||||
|
||||
if isinstance(task_config, dict) and "labels" in task_config:
|
||||
categories = task_config["labels"]
|
||||
else:
|
||||
categories = task_config
|
||||
|
||||
result = self._make_request(
|
||||
task="classify_text",
|
||||
text=text,
|
||||
schema={"categories": categories},
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
|
||||
# Wrap result with task name (only for formatted results)
|
||||
if format_results and isinstance(result, dict) and task_name not in result:
|
||||
return {task_name: result.get("classification", result)}
|
||||
return result
|
||||
else:
|
||||
# Multiple tasks - use schema endpoint
|
||||
schema = {"classifications": tasks}
|
||||
result = self._make_request(
|
||||
task="schema",
|
||||
text=text,
|
||||
schema=schema,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
return result
|
||||
|
||||
def batch_classify_text(
|
||||
self,
|
||||
texts: List[str],
|
||||
tasks: Dict[str, Union[List[str], Dict[str, Any]]],
|
||||
batch_size: int = 8,
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Batch classify multiple texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to classify.
|
||||
tasks: Classification tasks.
|
||||
batch_size: Batch size.
|
||||
threshold: Confidence threshold.
|
||||
format_results: Whether to format results. If False, returns raw extraction data.
|
||||
include_confidence: Whether to include confidence scores.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
List of classification results.
|
||||
If include_confidence=True, results include confidence scores.
|
||||
If format_results=False, returns raw extraction data.
|
||||
"""
|
||||
# Use schema task for batch classification
|
||||
schema = {"classifications": tasks}
|
||||
result = self._make_request(
|
||||
task="schema",
|
||||
text=texts,
|
||||
schema=schema,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
|
||||
if isinstance(result, dict):
|
||||
return [result]
|
||||
return result
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# JSON Extraction Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def extract_json(
|
||||
self,
|
||||
text: str,
|
||||
structures: Dict[str, List[str]],
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract structured data from text.
|
||||
|
||||
Args:
|
||||
text: Text to extract data from.
|
||||
structures: Structure definitions with field specs.
|
||||
threshold: Minimum confidence threshold.
|
||||
format_results: Whether to format results. If False, returns raw extraction data.
|
||||
include_confidence: Whether to include confidence scores.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
Extracted structures keyed by structure name.
|
||||
If include_confidence=True, field values include confidence scores.
|
||||
If include_spans=True, field values include start/end positions.
|
||||
If format_results=False, returns raw extraction data with positions.
|
||||
"""
|
||||
result = self._make_request(
|
||||
task="extract_json",
|
||||
text=text,
|
||||
schema=structures,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
return result
|
||||
|
||||
def batch_extract_json(
|
||||
self,
|
||||
texts: List[str],
|
||||
structures: Dict[str, List[str]],
|
||||
batch_size: int = 8,
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Batch extract structured data from multiple texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts.
|
||||
structures: Structure definitions.
|
||||
batch_size: Batch size.
|
||||
threshold: Confidence threshold.
|
||||
format_results: Whether to format results. If False, returns raw extraction data.
|
||||
include_confidence: Whether to include confidence scores.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
List of extracted structures.
|
||||
If include_confidence=True, field values include confidence scores.
|
||||
If include_spans=True, field values include start/end positions.
|
||||
If format_results=False, returns raw extraction data with positions.
|
||||
"""
|
||||
result = self._make_request(
|
||||
task="extract_json",
|
||||
text=texts,
|
||||
schema=structures,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
|
||||
if isinstance(result, dict):
|
||||
return [result]
|
||||
return result
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Relation Extraction Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def extract_relations(
|
||||
self,
|
||||
text: str,
|
||||
relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract relations between entities from text.
|
||||
|
||||
Args:
|
||||
text: Input text to extract relations from.
|
||||
relation_types: Relation types to extract. Can be:
|
||||
- str: Single relation type
|
||||
- List[str]: Multiple relation types
|
||||
- Dict[str, str]: Relation types with descriptions
|
||||
- Dict[str, Dict]: Relation types with full configuration
|
||||
threshold: Minimum confidence threshold.
|
||||
format_results: Whether to format results. If False, returns raw extraction data.
|
||||
include_confidence: Whether to include confidence scores in results.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
Dictionary with "relation_extraction" key containing extracted relations.
|
||||
Relations are grouped by type with tuples (source, target).
|
||||
Format: {"relation_extraction": {"relation_name": [("source", "target"), ...]}}
|
||||
"""
|
||||
# Build schema with relations
|
||||
schema = self.create_schema().relations(relation_types).build()
|
||||
|
||||
result = self._make_request(
|
||||
task="schema",
|
||||
text=text,
|
||||
schema=schema,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def batch_extract_relations(
|
||||
self,
|
||||
texts: List[str],
|
||||
relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
|
||||
batch_size: int = 8,
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Batch extract relations from multiple texts.
|
||||
|
||||
Args:
|
||||
texts: List of input texts.
|
||||
relation_types: Relation types to extract.
|
||||
batch_size: Batch size (used by API for optimization).
|
||||
threshold: Minimum confidence threshold.
|
||||
format_results: Whether to format results.
|
||||
include_confidence: Whether to include confidence scores.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with "relation_extraction" key.
|
||||
Format: [{"relation_extraction": {"relation_name": [("source", "target"), ...]}}]
|
||||
"""
|
||||
# Build schema with relations
|
||||
schema = self.create_schema().relations(relation_types).build()
|
||||
|
||||
result = self._make_request(
|
||||
task="schema",
|
||||
text=texts,
|
||||
schema=schema,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
|
||||
# Ensure result is a list
|
||||
if isinstance(result, dict):
|
||||
return [result]
|
||||
return result
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# General Extraction Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def extract(
|
||||
self,
|
||||
text: str,
|
||||
schema: Union[SchemaAPI, Dict[str, Any]],
|
||||
threshold: float = 0.5,
|
||||
format_results: bool = True,
|
||||
include_confidence: bool = False,
|
||||
include_spans: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract information from text using a schema.
|
||||
|
||||
Args:
|
||||
text: Input text to extract from.
|
||||
schema: Schema defining what to extract.
|
||||
threshold: Minimum confidence threshold.
|
||||
format_results: Whether to format results. If False, returns raw extraction data.
|
||||
include_confidence: Whether to include confidence scores.
|
||||
include_spans: Whether to include character-level start/end positions.
|
||||
|
||||
Returns:
|
||||
Extraction results organized by task name.
|
||||
If include_confidence=True, values include confidence scores.
|
||||
If include_spans=True, values include start/end positions.
|
||||
If format_results=False, returns raw extraction data with positions.
|
||||
"""
|
||||
# Build schema dict if needed
|
||||
if isinstance(schema, SchemaAPI):
|
||||
schema_dict = schema.build()
|
||||
elif hasattr(schema, 'build'):
|
||||
schema_dict = schema.build()
|
||||
else:
|
||||
schema_dict = schema
|
||||
|
||||
# Validate schema has at least one extraction task
|
||||
has_any_task = any(
|
||||
key in schema_dict
|
||||
for key in ["entities", "classifications", "structures", "relations"]
|
||||
)
|
||||
if not has_any_task:
|
||||
raise ValueError("Schema must contain at least one extraction task")
|
||||
|
||||
# Always use schema task to preserve all metadata (thresholds, dtypes, etc.)
|
||||
return self._make_request(
|
||||
task="schema",
|
||||
text=text,
|
||||
schema=schema_dict,
|
||||
threshold=threshold,
|
||||
include_confidence=include_confidence,
|
||||
include_spans=include_spans,
|
||||
format_results=format_results,
|
||||
)
|
||||
|
||||
    def batch_extract(
        self,
        texts: List[str],
        schemas: Union[SchemaAPI, List[SchemaAPI], Dict[str, Any], List[Dict[str, Any]]],
        batch_size: int = 8,
        threshold: float = 0.5,
        format_results: bool = True,
        include_confidence: bool = False,
        include_spans: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Extract information from multiple texts.

        Args:
            texts: List of input texts.
            schemas: A single schema applied to all texts, or a list of schemas (one per text).
            batch_size: Batch size.
            threshold: Confidence threshold.
            format_results: Whether to format results. If False, returns raw extraction data.
            include_confidence: Whether to include confidence scores.
            include_spans: Whether to include character-level start/end positions.

        Returns:
            List of extraction results.
            If include_confidence=True, values include confidence scores.
            If include_spans=True, values include start/end positions.
            If format_results=False, returns raw extraction data with positions.
        """
        if not texts:
            return []

        # Handle schema variations
        if isinstance(schemas, list):
            if len(schemas) != len(texts):
                raise ValueError(
                    f"Number of schemas ({len(schemas)}) must match number of texts ({len(texts)})"
                )
            # Warn user about the multi-schema batch limitation
            warnings.warn(
                "Multi-schema batch (different schemas per text) is not natively supported by the API. "
                "Each text will be processed individually, which may be slower than single-schema batch. "
                "For better performance, use the same schema for all texts.",
                UserWarning,
                stacklevel=2,
            )
            # Process each text with its schema individually
            results = []
            for text, schema in zip(texts, schemas):
                results.append(
                    self.extract(
                        text,
                        schema,
                        threshold,
                        include_confidence=include_confidence,
                        include_spans=include_spans,
                        format_results=format_results,
                    )
                )
            return results

        # Single schema for all texts
        if isinstance(schemas, SchemaAPI) or hasattr(schemas, 'build'):
            schema_dict = schemas.build()
        else:
            schema_dict = schemas

        return self._make_request(
            task="schema",
            text=texts,
            schema=schema_dict,
            threshold=threshold,
            include_confidence=include_confidence,
            include_spans=include_spans,
            format_results=format_results,
        )
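The multi-schema fallback above has two parts: validate that the schema list matches the text list, then dispatch one single-text call per pair. The same control flow, sketched independently of the HTTP client; `batch_with_fallback` and `fake_extract` are illustrative stand-ins, not package code:

```python
from typing import Any, Callable, Dict, List


def batch_with_fallback(
    texts: List[str],
    schemas: List[Dict[str, Any]],
    extract_one: Callable[[str, Dict[str, Any]], Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Process each text with its own schema, mirroring the multi-schema path."""
    if len(schemas) != len(texts):
        raise ValueError(
            f"Number of schemas ({len(schemas)}) must match number of texts ({len(texts)})"
        )
    return [extract_one(t, s) for t, s in zip(texts, schemas)]


def fake_extract(text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in for the real single-text extraction call.
    return {"text": text, "tasks": sorted(schema)}


out = batch_with_fallback(["a", "b"], [{"entities": []}, {"relations": []}], fake_extract)
assert out == [{"text": "a", "tasks": ["entities"]}, {"text": "b", "tasks": ["relations"]}]
```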
    # -------------------------------------------------------------------------
    # Utility Methods
    # -------------------------------------------------------------------------

    def close(self):
        """Close the HTTP session."""
        self.session.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()
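The `__enter__`/`__exit__` pair above guarantees the session is closed even if the `with` body raises. A minimal standalone sketch of the same pattern with a dummy session; `DummySession` and `Client` are illustrative names:

```python
class DummySession:
    """Stand-in for an HTTP session that records whether close() ran."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


class Client:
    def __init__(self):
        self.session = DummySession()

    def close(self):
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


with Client() as c:
    pass
assert c.session.closed is True
```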
1
packages/GLiNER2/gliner2/inference/__init__.py
Normal file
@@ -0,0 +1 @@
from .engine import RegexValidator, GLiNER2

1458
packages/GLiNER2/gliner2/inference/engine.py
Normal file
File diff suppressed because it is too large
Load Diff

191
packages/GLiNER2/gliner2/inference/schema_model.py
Normal file
@@ -0,0 +1,191 @@
"""
Pydantic models for validating schema input from JSON/dict.

This module provides validation models for creating GLiNER2 schemas
from JSON or dictionary inputs.
"""

from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator


class FieldInput(BaseModel):
    """Validates a single structure field.

    Args:
        name: Field name
        dtype: Data type - 'str' for a single value, 'list' for multiple values
        choices: Optional list of valid choices for classification-style fields
        description: Optional description of the field
    """
    name: str = Field(..., min_length=1, description="Field name")
    dtype: Literal["str", "list"] = Field(default="list", description="Data type")
    choices: Optional[List[str]] = Field(default=None, description="Valid choices")
    description: Optional[str] = Field(default=None, description="Field description")

    @field_validator('choices')
    @classmethod
    def validate_choices(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        """Ensure the choices list is not empty if provided."""
        if v is not None and len(v) == 0:
            raise ValueError("choices must contain at least one option")
        return v


class StructureInput(BaseModel):
    """Validates a structure block.

    Args:
        fields: List of field definitions
    """
    fields: List[FieldInput] = Field(..., min_length=1, description="List of fields")


class ClassificationInput(BaseModel):
    """Validates a classification task.

    Args:
        task: Task name
        labels: List of classification labels
        multi_label: Whether multiple labels can be selected
    """
    task: str = Field(..., min_length=1, description="Task name")
    labels: List[str] = Field(..., min_length=2, description="Classification labels")
    multi_label: bool = Field(default=False, description="Multi-label classification")

    @field_validator('labels')
    @classmethod
    def validate_labels(cls, v: List[str]) -> List[str]:
        """Ensure labels are unique and non-empty."""
        if len(v) != len(set(v)):
            raise ValueError("labels must be unique")
        if any(not label.strip() for label in v):
            raise ValueError("labels cannot be empty strings")
        return v


class SchemaInput(BaseModel):
    """Root schema validation model.

    Args:
        entities: List of entity types or dict mapping types to descriptions
        structures: Dict mapping structure names to structure definitions
        classifications: List of classification task definitions
        relations: List of relation types or dict mapping types to config
    """
    entities: Optional[Union[List[str], Dict[str, str]]] = Field(
        default=None,
        description="Entity types"
    )
    structures: Optional[Dict[str, StructureInput]] = Field(
        default=None,
        description="Structure definitions"
    )
    classifications: Optional[List[ClassificationInput]] = Field(
        default=None,
        description="Classification tasks"
    )
    relations: Optional[Union[List[str], Dict[str, Dict[str, Any]]]] = Field(
        default=None,
        description="Relation types"
    )

    @field_validator('entities')
    @classmethod
    def validate_entities(
        cls,
        v: Optional[Union[List[str], Dict[str, str]]]
    ) -> Optional[Union[List[str], Dict[str, str]]]:
        """Validate entities format."""
        if v is None:
            return v

        if isinstance(v, list):
            if len(v) == 0:
                raise ValueError("entities list cannot be empty")
            if any(not entity.strip() for entity in v):
                raise ValueError("entity names cannot be empty strings")
            if len(v) != len(set(v)):
                raise ValueError("entity names must be unique")
        elif isinstance(v, dict):
            if len(v) == 0:
                raise ValueError("entities dict cannot be empty")
            if any(not key.strip() for key in v.keys()):
                raise ValueError("entity names cannot be empty strings")

        return v

    @field_validator('structures')
    @classmethod
    def validate_structures(
        cls,
        v: Optional[Dict[str, StructureInput]]
    ) -> Optional[Dict[str, StructureInput]]:
        """Validate structures format."""
        if v is None:
            return v

        if len(v) == 0:
            raise ValueError("structures dict cannot be empty")
        if any(not key.strip() for key in v.keys()):
            raise ValueError("structure names cannot be empty strings")

        return v

    @field_validator('classifications')
    @classmethod
    def validate_classifications(
        cls,
        v: Optional[List[ClassificationInput]]
    ) -> Optional[List[ClassificationInput]]:
        """Validate classifications format."""
        if v is None:
            return v

        if len(v) == 0:
            raise ValueError("classifications list cannot be empty")

        # Check for duplicate task names
        task_names = [cls_task.task for cls_task in v]
        if len(task_names) != len(set(task_names)):
            raise ValueError("classification task names must be unique")

        return v

    @field_validator('relations')
    @classmethod
    def validate_relations(
        cls,
        v: Optional[Union[List[str], Dict[str, Dict[str, Any]]]]
    ) -> Optional[Union[List[str], Dict[str, Dict[str, Any]]]]:
        """Validate relations format."""
        if v is None:
            return v

        if isinstance(v, list):
            if len(v) == 0:
                raise ValueError("relations list cannot be empty")
            if any(not rel.strip() for rel in v):
                raise ValueError("relation names cannot be empty strings")
            if len(v) != len(set(v)):
                raise ValueError("relation names must be unique")
        elif isinstance(v, dict):
            if len(v) == 0:
                raise ValueError("relations dict cannot be empty")
            if any(not key.strip() for key in v.keys()):
                raise ValueError("relation names cannot be empty strings")

        return v

    @model_validator(mode='after')
    def validate_at_least_one_section(self) -> 'SchemaInput':
        """Ensure at least one section is provided."""
        if all(
            getattr(self, field) is None
            for field in ['entities', 'structures', 'classifications', 'relations']
        ):
            raise ValueError(
                "At least one of entities, structures, classifications, "
                "or relations must be provided"
            )
        return self
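The validators above reduce to a handful of reusable checks: non-empty collections, unique names, no blank strings, and at least one schema section. A dependency-free sketch of the same rules without pydantic; `check_unique_labels` and `check_any_section` are illustrative helpers, not part of the package:

```python
from typing import Dict, List, Optional


def check_unique_labels(labels: List[str]) -> List[str]:
    """Mirror ClassificationInput.validate_labels: unique, no blank strings."""
    if len(labels) != len(set(labels)):
        raise ValueError("labels must be unique")
    if any(not label.strip() for label in labels):
        raise ValueError("labels cannot be empty strings")
    return labels


def check_any_section(schema: Dict[str, Optional[object]]) -> None:
    """Mirror validate_at_least_one_section, applied to a plain dict."""
    sections = ["entities", "structures", "classifications", "relations"]
    if all(schema.get(s) is None for s in sections):
        raise ValueError("At least one section must be provided")


assert check_unique_labels(["pos", "neg"]) == ["pos", "neg"]
check_any_section({"entities": ["person"]})  # passes: one section present
```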
249
packages/GLiNER2/gliner2/layers.py
Normal file
@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def create_mlp(input_dim, intermediate_dims, output_dim, dropout=0.1, activation="gelu", add_layer_norm=False):
    """
    Creates a multi-layer perceptron (MLP) with specified dimensions and activation functions.
    """
    activation_mapping = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "leaky_relu": nn.LeakyReLU,
        "gelu": nn.GELU,
    }
    layers = []
    in_dim = input_dim
    for dim in intermediate_dims:
        layers.append(nn.Linear(in_dim, dim))
        if add_layer_norm:
            layers.append(nn.LayerNorm(dim))
        layers.append(activation_mapping[activation]())
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        in_dim = dim
    layers.append(nn.Linear(in_dim, output_dim))
    return nn.Sequential(*layers)

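`create_mlp` chains Linear → (optional LayerNorm) → activation → (optional Dropout) per hidden dimension, then a final Linear. The resulting sequence of linear shapes can be computed without torch; `mlp_linear_shapes` is an illustrative helper, not part of the package:

```python
from typing import List, Tuple


def mlp_linear_shapes(input_dim: int, intermediate_dims: List[int], output_dim: int) -> List[Tuple[int, int]]:
    """(in, out) shape of every nn.Linear that create_mlp would build."""
    shapes = []
    in_dim = input_dim
    for dim in intermediate_dims:
        shapes.append((in_dim, dim))
        in_dim = dim
    shapes.append((in_dim, output_dim))
    return shapes


# e.g. a classifier head like the one in model.py: hidden -> 2*hidden -> 1
assert mlp_linear_shapes(768, [1536], 1) == [(768, 1536), (1536, 1)]
```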
class DownscaledTransformer(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads=4, num_layers=2, dropout=0.1):
        """
        Initializes a downscaled transformer with the specified parameters.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_layers = num_layers

        self.in_projector = nn.Linear(input_size, hidden_size)

        encoder = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size * 2,
            dropout=dropout,
            batch_first=True,
        )

        self.transformer = nn.TransformerEncoder(encoder, num_layers=num_layers)

        self.out_projector = create_mlp(
            input_dim=hidden_size + input_size,
            intermediate_dims=[input_size, input_size],
            output_dim=input_size,
            dropout=0.,
            activation="relu",
            add_layer_norm=False,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor of shape (L, M, input_size).
        Returns:
            Tensor: Output tensor of shape (L, M, input_size).
        """
        original_x = x
        # Project input to hidden size.
        x = self.in_projector(x)
        # Apply transformer encoder.
        x = self.transformer(x)
        # Concatenate original input with transformer output.
        x = torch.cat([x, original_x], dim=-1)
        # Project back to input size.
        x = self.out_projector(x)
        return x

class CountLSTM(nn.Module):
    def __init__(self, hidden_size, max_count=20):
        """
        Initializes the module with a learned positional embedding for count steps and a GRU.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.max_count = max_count
        # Learned positional embeddings for count steps: shape (max_count, hidden_size)
        self.pos_embedding = nn.Embedding(max_count, hidden_size)
        # Use a GRU layer; input shape is (seq_len, batch, input_size)
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)
        # Projector layer: combines GRU output with the original embeddings.
        self.projector = create_mlp(
            input_dim=hidden_size * 2,
            intermediate_dims=[hidden_size * 4],
            output_dim=hidden_size,
            dropout=0.,
            activation="relu",
            add_layer_norm=False,
        )

    def forward(self, pc_emb: torch.Tensor, gold_count_val: int) -> torch.Tensor:
        """
        Args:
            pc_emb (Tensor): Field embeddings of shape (M, hidden_size).
            gold_count_val (int): Predicted count value (number of steps).
        Returns:
            Tensor: Count-aware structure embeddings of shape (gold_count_val, M, hidden_size).
        """
        M, D = pc_emb.shape
        # Cap gold_count_val at max_count.
        gold_count_val = min(gold_count_val, self.max_count)
        # Create a sequence of count indices: shape (gold_count_val,)
        count_indices = torch.arange(gold_count_val, device=pc_emb.device)
        # Get positional embeddings for each count: (gold_count_val, hidden_size)
        pos_seq = self.pos_embedding(count_indices)
        # Expand pos_seq over the batch dimension: (gold_count_val, M, hidden_size)
        pos_seq = pos_seq.unsqueeze(1).expand(gold_count_val, M, D)
        # Initialize the GRU hidden state with the field embeddings.
        h0 = pc_emb.unsqueeze(0)  # shape: (1, M, hidden_size)
        # Run the GRU over the count sequence.
        output, _ = self.gru(pos_seq, h0)
        # Concatenate the GRU outputs with the original field embeddings.
        return self.projector(torch.cat([output, pc_emb.unsqueeze(0).expand_as(output)], dim=-1))

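The count unrolling above boils down to: clamp the predicted count to `max_count`, then enumerate one positional-embedding index per step. That arithmetic in isolation, without torch; `count_steps` is an illustrative name:

```python
from typing import List


def count_steps(gold_count_val: int, max_count: int = 20) -> List[int]:
    """Indices fed to the positional embedding, after capping at max_count."""
    return list(range(min(gold_count_val, max_count)))


assert count_steps(3) == [0, 1, 2]
assert count_steps(25) == list(range(20))  # capped at max_count
```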
class CountLSTMv2(nn.Module):
    def __init__(self, hidden_size, max_count=20):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_count = max_count
        self.pos_embedding = nn.Embedding(max_count, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.transformer = DownscaledTransformer(
            hidden_size,
            hidden_size=128,
            num_heads=4,
            num_layers=2,
            dropout=0.1,
        )

    # NOTE: gold_count_val may arrive as a 0-D Tensor (e.g. during ONNX export),
    # not only a Python int.
    def forward(self, pc_emb: torch.Tensor, gold_count_val: int) -> torch.Tensor:
        M, D = pc_emb.size()  # symbolic sizes

        # Clamp without dropping to Python
        gold_count_val = min(gold_count_val, self.max_count)

        # Build the *full* index vector once, then slice - ONNX supports both ops
        full_idx = torch.arange(self.max_count, device=pc_emb.device)
        count_idx = full_idx[:gold_count_val]  # (gold_count_val,)

        pos_seq = self.pos_embedding(count_idx)           # (gold_count_val, D)
        pos_seq = pos_seq.unsqueeze(1).expand(-1, M, -1)  # (gold_count_val, M, D)

        h0 = pc_emb.unsqueeze(0)            # (1, M, D)
        output, _ = self.gru(pos_seq, h0)   # (gold_count_val, M, D)

        pc_broadcast = pc_emb.unsqueeze(0).expand_as(output)
        return self.transformer(output + pc_broadcast)

class CountLSTMoE(nn.Module):
    """
    Count-aware module with a Mixture-of-Experts projector.

    Args
    ----
    hidden_size : int
        Model dimensionality (D).
    max_count : int
        Maximum number of count steps L.
    n_experts : int, optional
        Number of FFN experts (default = 4).
    ffn_mult : int, optional
        Expansion factor for the expert bottleneck (default = 2 → inner = 2D).
    dropout : float, optional
        Dropout used inside the expert FFNs.
    """

    def __init__(self,
                 hidden_size: int,
                 max_count: int = 20,
                 n_experts: int = 4,
                 ffn_mult: int = 2,
                 dropout: float = 0.1):
        super().__init__()
        self.hidden_size, self.max_count, self.n_experts = (
            hidden_size, max_count, n_experts)

        # ───── positional encoding + recurrent core ─────
        self.pos_embedding = nn.Embedding(max_count, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

        # ───── expert parameters (all packed in two tensors) ─────
        inner = hidden_size * ffn_mult
        # W1 : [E, D, inner]   b1 : [E, inner]
        self.w1 = nn.Parameter(torch.empty(n_experts, hidden_size, inner))
        self.b1 = nn.Parameter(torch.zeros(n_experts, inner))
        # W2 : [E, inner, D]   b2 : [E, D]
        self.w2 = nn.Parameter(torch.empty(n_experts, inner, hidden_size))
        self.b2 = nn.Parameter(torch.zeros(n_experts, hidden_size))

        # Better than the default init for large matrices
        nn.init.xavier_uniform_(self.w1)
        nn.init.xavier_uniform_(self.w2)

        self.dropout = nn.Dropout(dropout)

        # ───── router / gating network ─────
        self.router = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, n_experts),  # logits
            nn.Softmax(dim=-1),                 # gates sum to 1
        )

    # ---------------------------------------------------
    def forward(self, pc_emb: torch.Tensor, gold_count_val: int) -> torch.Tensor:
        """
        pc_emb         : [M, D] field embeddings
        gold_count_val : int (number of count steps to unroll)
        returns        : [L, M, D] count-aware embeddings
        """
        M, D = pc_emb.shape
        L = min(gold_count_val, self.max_count)

        idx = torch.arange(L, device=pc_emb.device)
        pos_seq = self.pos_embedding(idx).unsqueeze(1).expand(L, M, D)

        h0 = pc_emb.unsqueeze(0)      # [1, M, D]
        h, _ = self.gru(pos_seq, h0)  # [L, M, D]

        # ───── routing / gating ─────
        gates = self.router(h)  # [L, M, E]

        # ───── expert FFN: run *all* experts in parallel ─────
        # 1st linear
        x = torch.einsum('lmd,edh->lmeh', h, self.w1) + self.b1  # [L, M, E, inner]
        x = F.gelu(x)
        x = self.dropout(x)
        # 2nd linear
        x = torch.einsum('lmeh,ehd->lmed', x, self.w2) + self.b2  # [L, M, E, D]

        # ───── mixture weighted by gates ─────
        out = (gates.unsqueeze(-1) * x).sum(dim=2)  # [L, M, D]
        return out
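The final mixture step above is simply a gate-weighted average of per-expert outputs. With scalars instead of tensors, and gates assumed to sum to 1 (as the Softmax guarantees); `mix_experts` is an illustrative helper:

```python
from typing import List


def mix_experts(gates: List[float], expert_outputs: List[float]) -> float:
    """Gate-weighted sum: the per-position scalar analogue of
    (gates.unsqueeze(-1) * x).sum(dim=2)."""
    assert len(gates) == len(expert_outputs)
    return sum(g * y for g, y in zip(gates, expert_outputs))


# Two experts; the router puts 75% of the weight on the second one.
assert mix_experts([0.25, 0.75], [4.0, 8.0]) == 7.0
```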
692
packages/GLiNER2/gliner2/model.py
Normal file
@@ -0,0 +1,692 @@
"""
GLiNER2 Extractor Model with Optimized Batch Processing

This module contains the core Extractor model that accepts a PreprocessedBatch
directly for efficient GPU-only forward passes.
"""

import os
import tempfile
from typing import Dict, List, Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from gliner.modeling.span_rep import SpanRepLayer
from gliner2.layers import CountLSTMoE, CountLSTM, create_mlp, CountLSTMv2
from gliner2.processor import SchemaTransformer, PreprocessedBatch, SamplingConfig
from safetensors.torch import save_file, load_file
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoModel,
    AutoConfig,
    AutoTokenizer,
)

class ExtractorConfig(PretrainedConfig):
    """Configuration for the Extractor model."""
    model_type = "extractor"

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        max_width: int = 8,
        counting_layer: str = "count_lstm",
        token_pooling: str = "first",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.max_width = max_width
        self.counting_layer = counting_layer
        self.token_pooling = token_pooling

class Extractor(PreTrainedModel):
    """
    GLiNER2 Extractor Model.

    This model accepts a PreprocessedBatch for efficient training.
    Use processor.collate_fn_train() to create batches.

    Example:
        >>> processor = SchemaTransformer(model_name)
        >>> model = Extractor.from_pretrained(repo_id)
        >>>
        >>> # Training
        >>> loader = DataLoader(dataset, collate_fn=processor.collate_fn_train)
        >>> for batch in loader:
        ...     batch = batch.to(device)
        ...     loss = model(batch)["total_loss"]
    """
    config_class = ExtractorConfig

    def __init__(self, config: ExtractorConfig, encoder_config=None, tokenizer=None):
        super().__init__(config)
        self.config = config
        self.max_width = config.max_width

        # Initialize the processor
        if tokenizer is not None:
            self.processor = SchemaTransformer(
                tokenizer=tokenizer,
                token_pooling=config.token_pooling
            )
        else:
            self.processor = SchemaTransformer(
                config.model_name,
                token_pooling=config.token_pooling
            )

        # Load the encoder
        if encoder_config is not None:
            self.encoder = AutoModel.from_config(encoder_config, trust_remote_code=True)
        else:
            self.encoder = AutoModel.from_pretrained(config.model_name, trust_remote_code=True)

        self.encoder.resize_token_embeddings(len(self.processor.tokenizer))
        self.hidden_size = self.encoder.config.hidden_size

        # Span representation layer
        self.span_rep = SpanRepLayer(
            span_mode="markerV0",
            hidden_size=self.hidden_size,
            max_width=self.max_width,
            dropout=0.1,
        )

        # Classifier for classification tasks
        self.classifier = create_mlp(
            input_dim=self.hidden_size,
            intermediate_dims=[self.hidden_size * 2],
            output_dim=1,
            dropout=0.,
            activation="relu",
            add_layer_norm=False
        )

        # Count prediction layer
        self.count_pred = create_mlp(
            input_dim=self.hidden_size,
            intermediate_dims=[self.hidden_size * 2],
            output_dim=20,
            dropout=0.,
            activation="relu",
            add_layer_norm=False
        )

        # Count embedding module
        if config.counting_layer == "count_lstm":
            self.count_embed = CountLSTM(self.hidden_size)
        elif config.counting_layer == "count_lstm_moe":
            self.count_embed = CountLSTMoE(
                hidden_size=self.hidden_size,
                n_experts=4,
                ffn_mult=2,
                dropout=0.1
            )
        elif config.counting_layer == "count_lstm_v2":
            self.count_embed = CountLSTMv2(hidden_size=self.hidden_size)

        # LoRA adapter state
        self._lora_layers = {}
        self._adapter_config = None

        self._print_config(config)

    def _print_config(self, config):
        print("=" * 60)
        print("🧠 Model Configuration")
        print("=" * 60)
        print(f"Encoder model  : {config.model_name}")
        print(f"Counting layer : {config.counting_layer}")
        print(f"Token pooling  : {config.token_pooling}")
        print("=" * 60)

    # =========================================================================
    # Main Forward Pass
    # =========================================================================

    def forward(
        self,
        batch: PreprocessedBatch,
        return_individual_losses: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass on a preprocessed batch.

        Args:
            batch: PreprocessedBatch from processor.collate_fn_train()
            return_individual_losses: If True, also return per-sample losses

        Returns:
            Dict with:
                - total_loss: Sum of all losses
                - classification_loss: Classification task loss
                - structure_loss: Span extraction loss
                - count_loss: Count prediction loss
                - batch_size: Number of valid samples
        """
        if len(batch) == 0:
            return self._empty_loss_dict()

        device = next(self.parameters()).device
        batch = batch.to(device)

        # Encode the batch through the transformer
        all_token_embs, all_schema_embs = self._encode_batch(batch)

        # Compute losses for each sample
        cls_losses = []
        struct_losses = []
        count_losses = []
        individual = []
        valid_samples = 0

        for i in range(len(batch)):
            try:
                sample_losses = self._compute_sample_loss(
                    token_embeddings=all_token_embs[i],
                    embs_per_schema=all_schema_embs[i],
                    task_types=batch.task_types[i],
                    structure_labels=batch.structure_labels[i],
                    device=device
                )

                cls_losses.append(sample_losses["classification"])
                struct_losses.append(sample_losses["structure"])
                count_losses.append(sample_losses["count"])

                if return_individual_losses:
                    individual.append({
                        "total_loss": (
                            sample_losses["classification"] +
                            sample_losses["structure"] +
                            sample_losses["count"]
                        ).item(),
                        "classification_loss": sample_losses["classification"].item(),
                        "structure_loss": sample_losses["structure"].item(),
                        "count_loss": sample_losses["count"].item(),
                    })

                valid_samples += 1

            except Exception as e:
                print(f"Error processing sample {i}: {e}")
                zero = torch.tensor(0.0, device=device)
                cls_losses.append(zero)
                struct_losses.append(zero)
                count_losses.append(zero)

                if return_individual_losses:
                    individual.append({
                        "total_loss": 0.0,
                        "classification_loss": 0.0,
                        "structure_loss": 0.0,
                        "count_loss": 0.0,
                        "error": str(e)
                    })

        if valid_samples == 0:
            result = self._empty_loss_dict()
            if return_individual_losses:
                result["individual_losses"] = individual
            return result

        # Aggregate losses
        total_cls = torch.stack(cls_losses).sum()
        total_struct = torch.stack(struct_losses).sum()
        total_count = torch.stack(count_losses).sum()
        total_loss = total_cls + total_struct + total_count

        result = {
            "total_loss": total_loss,
            "classification_loss": total_cls,
            "structure_loss": total_struct,
            "count_loss": total_count,
            "batch_size": valid_samples
        }

        if return_individual_losses:
            result["individual_losses"] = individual

        return result
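The per-sample loop above degrades gracefully: a sample that raises contributes zero loss instead of aborting the whole batch, and totals are summed at the end. The same control flow with plain floats; `aggregate_losses` and `risky_loss` are illustrative stand-ins:

```python
from typing import Callable, Dict, List


def aggregate_losses(samples: List[dict], loss_fn: Callable[[dict], float]) -> Dict[str, float]:
    """Sum per-sample losses, substituting 0.0 for samples that raise."""
    total = 0.0
    valid = 0
    for sample in samples:
        try:
            total += loss_fn(sample)
            valid += 1
        except Exception:
            total += 0.0  # a failed sample contributes nothing
    return {"total_loss": total, "batch_size": valid}


def risky_loss(sample: dict) -> float:
    return sample["loss"]  # raises KeyError for malformed samples


out = aggregate_losses([{"loss": 1.5}, {}, {"loss": 0.5}], risky_loss)
assert out == {"total_loss": 2.0, "batch_size": 2}
```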
    def _empty_loss_dict(self) -> Dict[str, torch.Tensor]:
        """Return an empty loss dictionary."""
        device = next(self.parameters()).device
        return {
            "total_loss": torch.tensor(0.0, device=device, requires_grad=True),
            "classification_loss": torch.tensor(0.0, device=device),
            "structure_loss": torch.tensor(0.0, device=device),
            "count_loss": torch.tensor(0.0, device=device),
            "batch_size": 0
        }

    # =========================================================================
    # Encoding
    # =========================================================================

    def _encode_batch(
        self,
        batch: PreprocessedBatch
    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        """
        Encode the batch through the transformer and extract embeddings.

        Args:
            batch: PreprocessedBatch with input_ids and attention_mask

        Returns:
            - all_token_embs: List of (text_len, hidden) tensors, one per sample
            - all_schema_embs: List of schema embeddings per sample
        """
        # Forward through the encoder
        outputs = self.encoder(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask
        )
        token_embeddings = outputs.last_hidden_state

        # Extract embeddings using the processor
        return self.processor.extract_embeddings_from_batch(
            token_embeddings,
            batch.input_ids,
            batch
        )

    # =========================================================================
    # Loss Computation
    # =========================================================================

    def _compute_sample_loss(
        self,
        token_embeddings: torch.Tensor,
        embs_per_schema: List[List[torch.Tensor]],
        task_types: List[str],
        structure_labels: List[Any],
        device: torch.device
    ) -> Dict[str, torch.Tensor]:
        """
        Compute all losses for a single sample.

        Args:
            token_embeddings: (text_len, hidden) text token embeddings
            embs_per_schema: List of schema embeddings
            task_types: Task type for each schema
            structure_labels: Labels for each schema
            device: Computation device

        Returns:
            Dict with classification, structure, and count losses
        """
        cls_loss = torch.tensor(0.0, device=device)
        struct_loss = torch.tensor(0.0, device=device)
        count_loss = torch.tensor(0.0, device=device)

        # Compute span representations if needed
        has_span_task = any(t != "classifications" for t in task_types)
        span_info = None
        if has_span_task and token_embeddings.numel() > 0:
            span_info = self.compute_span_rep(token_embeddings)

        all_counts = []
        all_p_embs = []

        for i, task_type in enumerate(task_types):
            if not embs_per_schema[i]:
                continue

            schema_emb = torch.stack(embs_per_schema[i])

            if task_type == "classifications":
                # Classification loss
                cls_embeds = schema_emb[1:]  # Skip the [P] token
                logits = self.classifier(cls_embeds).squeeze(-1)
                labels = torch.tensor(structure_labels[i], dtype=torch.float, device=device)
                cls_loss = cls_loss + F.binary_cross_entropy_with_logits(
                    logits, labels, reduction="sum"
                )
            else:
                # Structure loss
                structure = structure_labels[i]

                if structure[0] == 0:
                    # No instances to extract
                    continue

                if span_info is not None:
                    struct_loss = struct_loss + self.compute_struct_loss(
                        span_info["span_rep"],
                        schema_emb,
                        structure,
                        span_info["span_mask"]
                    )

                # Collect for the count loss (skip entities)
                if task_type != "entities":
                    all_counts.append(min(structure[0], 19))
                    all_p_embs.append(schema_emb[0])

        # Count loss
        if all_counts and all_p_embs:
            counts = torch.tensor(all_counts, dtype=torch.long, device=device)
            p_embs = torch.stack(all_p_embs)
            count_loss = F.cross_entropy(self.count_pred(p_embs), counts, reduction="sum")

        return {
            "classification": cls_loss,
            "structure": struct_loss,
            "count": count_loss
        }

    # =========================================================================
    # Span Representation
    # =========================================================================

def compute_span_rep(self, token_embeddings: torch.Tensor) -> Dict[str, Any]:
|
||||
"""
|
||||
Compute span representations for token embeddings.
|
||||
|
||||
Args:
|
||||
token_embeddings: (text_len, hidden) token embeddings
|
||||
|
||||
Returns:
|
||||
Dict with span_rep, spans_idx, and span_mask
|
||||
"""
|
||||
text_length = len(token_embeddings)
|
||||
device = token_embeddings.device
|
||||
|
||||
spans_idx = []
|
||||
for i in range(text_length):
|
||||
for j in range(self.max_width):
|
||||
if i + j < text_length:
|
||||
spans_idx.append((i, i + j))
|
||||
else:
|
||||
spans_idx.append((-1, -1))
|
||||
|
||||
spans_idx = torch.tensor([spans_idx], dtype=torch.long, device=device)
|
||||
|
||||
# Mask invalid spans
|
||||
span_mask = (spans_idx[:, :, 0] == -1) | (spans_idx[:, :, 1] == -1)
|
||||
|
||||
# Replace invalid with (0, 0) for safe indexing
|
||||
safe_spans = torch.where(
|
||||
span_mask.unsqueeze(-1),
|
||||
torch.zeros_like(spans_idx),
|
||||
spans_idx
|
||||
)
|
||||
|
||||
# Compute span representations
|
||||
span_rep = self.span_rep(
|
||||
token_embeddings.unsqueeze(0),
|
||||
safe_spans
|
||||
).squeeze(0)
|
||||
|
||||
return {
|
||||
"span_rep": span_rep,
|
||||
"spans_idx": spans_idx,
|
||||
"span_mask": span_mask
|
||||
}
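For reference, the span enumeration in `compute_span_rep` can be sketched standalone (plain Python, no torch): every start token `i` pairs with widths `0..max_width-1`, and spans that would run past the end of the text are padded with `(-1, -1)` so the tensor stays rectangular.

```python
def enumerate_spans(text_length: int, max_width: int):
    """Enumerate (start, end) token spans; pad out-of-range spans with (-1, -1)."""
    spans = []
    for i in range(text_length):
        for j in range(max_width):
            spans.append((i, i + j) if i + j < text_length else (-1, -1))
    return spans

print(enumerate_spans(3, 2))
# -> [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (-1, -1)]
```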

    def compute_struct_loss(
        self,
        span_rep: torch.Tensor,
        schema_emb: torch.Tensor,
        structure: List[Any],
        span_mask: torch.Tensor,
        masking_rate: float = 0.5
    ) -> torch.Tensor:
        """
        Compute structure extraction loss with negative span masking.

        Args:
            span_rep: (num_spans, hidden) span representations
            schema_emb: (num_fields + 1, hidden) schema embeddings
            structure: [count, spans] structure labels
            span_mask: (1, num_spans) mask for invalid spans
            masking_rate: Probability of masking negative spans

        Returns:
            Structure loss tensor
        """
        gold_count = min(structure[0], 19)
        struct_proj = self.count_embed(schema_emb[1:], gold_count)
        scores = torch.einsum('lkd,bpd->bplk', span_rep, struct_proj)

        # Create label tensor
        labs = torch.zeros_like(scores)

        for i in range(gold_count):
            gold_spans = structure[1][i]
            for k, span in enumerate(gold_spans):
                if span is None or span == (-1, -1):
                    continue
                if isinstance(span, tuple):
                    start, end = span
                    width = end - start
                    if 0 <= start < scores.shape[2] and 0 <= width < scores.shape[3]:
                        labs[i, k, start, width] = 1
                elif isinstance(span, list):
                    for sub in span:
                        if sub is None or sub == (-1, -1):
                            continue
                        start, end = sub
                        width = end - start
                        if 0 <= start < scores.shape[2] and 0 <= width < scores.shape[3]:
                            labs[i, k, start, width] = 1

        # Apply negative masking
        if masking_rate > 0.0 and self.training:
            negative = (labs == 0)
            random_mask = torch.rand_like(scores) < masking_rate
            to_mask = negative & random_mask
            loss_mask = (~to_mask).float()
        else:
            loss_mask = torch.ones_like(scores)

        # Compute masked loss
        loss = F.binary_cross_entropy_with_logits(scores, labs, reduction="none")
        loss = loss * loss_mask
        loss = loss.view(loss.shape[0], loss.shape[1], -1) * (~span_mask[0]).float()

        return loss.sum()
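The negative-masking step above can be illustrated in isolation (a hedged, list-based sketch with a hypothetical helper name, not code from the module): positions whose label is 0 are each dropped from the loss with probability `masking_rate`, while positive positions always contribute.

```python
import random

def negative_loss_mask(labels, masking_rate, rng=random.Random(0)):
    """Return a per-position loss mask: keep all positives, drop a random
    fraction (masking_rate) of the negatives."""
    return [0.0 if (lab == 0 and rng.random() < masking_rate) else 1.0
            for lab in labels]

# With masking_rate=1.0 every negative is dropped; positives are always kept.
print(negative_loss_mask([1, 0, 0, 1, 0], masking_rate=1.0))
# -> [1.0, 0.0, 0.0, 1.0, 0.0]
```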

    # =========================================================================
    # Hugging Face Methods
    # =========================================================================

    def push_to_hub(self, repo_id: str, private: bool = True):
        """Push model to Hugging Face Hub."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.save_pretrained(tmp_dir)
            super().push_to_hub(repo_id=repo_id, save_dir=tmp_dir, private=private)
            self.processor.tokenizer.push_to_hub(repo_id)

    @classmethod
    def from_pretrained(cls, repo_or_dir: str, **kwargs):
        """
        Load model from Hugging Face Hub or local directory.

        To use a LoRA adapter:
        1. Load the base model first
        2. Then load the adapter using model.load_adapter()

        Example:
            model = Extractor.from_pretrained("base-model-name")
            model.load_adapter("path/to/adapter")
        """
        from huggingface_hub import hf_hub_download

        def download_or_local(repo, filename):
            if os.path.isdir(repo):
                return os.path.join(repo, filename)
            return hf_hub_download(repo, filename)

        config_path = download_or_local(repo_or_dir, "config.json")
        config = cls.config_class.from_pretrained(config_path)

        encoder_config_path = download_or_local(repo_or_dir, "encoder_config/config.json")
        encoder_config = AutoConfig.from_pretrained(encoder_config_path)

        tokenizer = AutoTokenizer.from_pretrained(repo_or_dir)
        model = cls(config, encoder_config=encoder_config, tokenizer=tokenizer)

        # Load weights
        try:
            model_path = download_or_local(repo_or_dir, "model.safetensors")
            state_dict = load_file(model_path)
        except Exception:
            model_path = download_or_local(repo_or_dir, "pytorch_model.bin")
            state_dict = torch.load(model_path, map_location="cpu")

        # Handle embedding size mismatch
        try:
            saved_emb = state_dict["encoder.embeddings.word_embeddings.weight"]
            model_emb = model.encoder.embeddings.word_embeddings.weight
            if saved_emb.shape[0] != model_emb.shape[0]:
                extra = model_emb.shape[0] - saved_emb.shape[0]
                state_dict["encoder.embeddings.word_embeddings.weight"] = torch.cat([
                    saved_emb,
                    torch.randn(extra, saved_emb.shape[1]) * 0.02
                ], dim=0)
        except KeyError:
            pass

        model.load_state_dict(state_dict)
        return model

    def load_adapter(self, adapter_path: str) -> 'Extractor':
        """
        Load a LoRA adapter onto this model.

        If an adapter is already loaded, it will be unloaded first.

        Args:
            adapter_path: Path to adapter directory

        Returns:
            self for method chaining

        Example:
            model.load_adapter("./legal_adapter")
            results = model.extract_entities(text, entities)
        """
        from gliner2.training.lora import load_lora_adapter, LoRAAdapterConfig

        # Load adapter config
        config = LoRAAdapterConfig.load(adapter_path)

        self._lora_layers = load_lora_adapter(self, adapter_path, auto_unload=True)
        self._adapter_config = config
        return self

    def unload_adapter(self) -> 'Extractor':
        """
        Unload current LoRA adapter, restoring base model.

        Returns:
            self for method chaining
        """
        from gliner2.training.lora import unload_lora_adapter

        if self._lora_layers:
            unload_lora_adapter(self)
            self._lora_layers = {}
            self._adapter_config = None
        return self

    def merge_lora(self) -> 'Extractor':
        """
        Merge LoRA weights into base model and remove adapter structure.

        After calling this, the model will have standard Linear layers with
        merged weights. LoRA adapters are permanently removed.

        Returns:
            self for method chaining

        Raises:
            ValueError: If no adapter is loaded

        Example:
            model.load_adapter("./my_adapter")
            model.merge_lora()  # Now model has merged weights, no LoRA
            model.save_pretrained("./merged_model")
        """
        if not self._lora_layers:
            raise ValueError("No adapter loaded. Nothing to merge.")

        from gliner2.training.lora import merge_lora_weights
        merge_lora_weights(self)
        self._lora_layers = {}
        self._adapter_config = None
        return self

    def save_adapter(self, save_path: str) -> None:
        """
        Save only the LoRA adapter (not full model).

        Args:
            save_path: Directory to save adapter

        Raises:
            ValueError: If no adapter is loaded
        """
        if not self._lora_layers:
            raise ValueError("No adapter loaded. Use save_pretrained for full model.")

        from gliner2.training.lora import save_lora_adapter
        save_lora_adapter(self, save_path)

    @property
    def has_adapter(self) -> bool:
        """Check if an adapter is currently loaded."""
        return bool(self._lora_layers)

    @property
    def adapter_config(self):
        """Get config of loaded adapter, or None."""
        return self._adapter_config

    def save_pretrained(
        self,
        save_directory: str,
        save_adapter_only: bool = False,
        merge_lora: bool = True,
        **kwargs
    ):
        """
        Save model to directory.

        Args:
            save_directory: Where to save
            save_adapter_only: If True and adapter loaded, save only adapter
            merge_lora: If True and LoRA active, merge LoRA weights into base
                model and remove adapter structure before saving.
                WARNING: This permanently removes LoRA from the model instance.
        """
        if save_adapter_only:
            if not self._lora_layers:
                raise ValueError("save_adapter_only=True but no adapter loaded")
            self.save_adapter(save_directory)
            return

        # Handle LoRA merging if requested
        if merge_lora and self._lora_layers:
            self.merge_lora()

        # Original save logic
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        encoder_config_path = os.path.join(save_directory, "encoder_config")
        os.makedirs(encoder_config_path, exist_ok=True)
        self.encoder.config.save_pretrained(encoder_config_path)

        model_path = os.path.join(save_directory, "model.safetensors")
        save_file(self.state_dict(), model_path)

        self.processor.tokenizer.save_pretrained(save_directory)
322 packages/GLiNER2/gliner2/old_trainer.py Normal file
@@ -0,0 +1,322 @@
"""
|
||||
GLiNER2 Trainer with Optimized DataLoader-based Preprocessing
|
||||
|
||||
This module provides training utilities that leverage parallel preprocessing
|
||||
via DataLoader workers for maximum GPU utilization.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from transformers import Trainer, TrainingArguments
|
||||
|
||||
from gliner2.processor import SchemaTransformer, PreprocessedBatch, SamplingConfig
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
|
||||
class ExtractorDataset(Dataset):
|
||||
"""
|
||||
Dataset for GLiNER2 training.
|
||||
|
||||
Returns (text, schema) tuples that are processed by the collate function.
|
||||
|
||||
Args:
|
||||
data_paths: Path or list of paths to JSONL training files
|
||||
shuffle: Whether to shuffle data on load (default: True)
|
||||
|
||||
JSONL Format:
|
||||
{"input": "text here", "output": {"entities": {...}, ...}}
|
||||
"""
|
||||
|
||||
def __init__(self, data_paths: Union[str, List[str]], shuffle: bool = True):
|
||||
if isinstance(data_paths, str):
|
||||
data_paths = [data_paths]
|
||||
|
||||
print(f"Loading {len(data_paths)} file(s) for training...")
|
||||
|
||||
self.data = []
|
||||
for path in data_paths:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
self.data.extend([json.loads(line) for line in f])
|
||||
|
||||
if shuffle:
|
||||
random.shuffle(self.data)
|
||||
|
||||
print(f"Loaded {len(self.data)} records from {len(data_paths)} file(s).")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx: int) -> tuple:
|
||||
"""Return (text, schema) tuple."""
|
||||
record = self.data[idx]
|
||||
return record["input"], record["output"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Collator
|
||||
# =============================================================================
|
||||
|
||||
class ExtractorDataCollator:
|
||||
"""
|
||||
Data collator that uses processor's collate function.
|
||||
|
||||
This enables parallel preprocessing via DataLoader workers.
|
||||
|
||||
Args:
|
||||
processor: SchemaTransformer instance
|
||||
is_training: Whether in training mode (enables augmentation)
|
||||
"""
|
||||
|
||||
def __init__(self, processor: SchemaTransformer, is_training: bool = True):
|
||||
self.processor = processor
|
||||
self.is_training = is_training
|
||||
|
||||
def __call__(self, batch: List[tuple]) -> PreprocessedBatch:
|
||||
"""
|
||||
Collate batch of (text, schema) tuples into PreprocessedBatch.
|
||||
|
||||
Args:
|
||||
batch: List of (text, schema) tuples from dataset
|
||||
|
||||
Returns:
|
||||
PreprocessedBatch ready for model.forward()
|
||||
"""
|
||||
if self.is_training:
|
||||
return self.processor.collate_fn_train(batch)
|
||||
else:
|
||||
return self.processor.collate_fn_inference(batch)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
|
||||
class ExtractorTrainer(Trainer):
|
||||
"""
|
||||
Trainer for GLiNER2 with optimized preprocessing.
|
||||
|
||||
Key features:
|
||||
- Parallel preprocessing via DataLoader workers
|
||||
- Separate learning rates for encoder and other layers
|
||||
- Optional classifier-only fine-tuning
|
||||
- FP16 support
|
||||
- Gradient accumulation
|
||||
|
||||
Example:
|
||||
>>> processor = SchemaTransformer(model_name, sampling_config=config)
|
||||
>>> collator = ExtractorDataCollator(processor, is_training=True)
|
||||
>>>
|
||||
>>> trainer = ExtractorTrainer(
|
||||
... model=model,
|
||||
... args=TrainingArguments(
|
||||
... output_dir="./output",
|
||||
... per_device_train_batch_size=32,
|
||||
... dataloader_num_workers=8, # Parallel preprocessing!
|
||||
... dataloader_pin_memory=True,
|
||||
... ),
|
||||
... train_dataset=dataset,
|
||||
... data_collator=collator,
|
||||
... encoder_lr=1e-5,
|
||||
... custom_lr=5e-4,
|
||||
... weight_decay=0.01,
|
||||
... )
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_lr: float = 1e-5,
|
||||
custom_lr: float = 5e-4,
|
||||
weight_decay: float = 0.01,
|
||||
finetune_classifier: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize trainer.
|
||||
|
||||
Args:
|
||||
encoder_lr: Learning rate for encoder parameters
|
||||
custom_lr: Learning rate for non-encoder parameters
|
||||
weight_decay: Weight decay for all parameters
|
||||
finetune_classifier: If True, freeze all except classifier
|
||||
**kwargs: Arguments passed to Trainer
|
||||
"""
|
||||
self.encoder_lr = encoder_lr
|
||||
self.custom_lr = custom_lr
|
||||
self.custom_weight_decay = weight_decay
|
||||
self.finetune_classifier = finetune_classifier
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if self.finetune_classifier:
|
||||
self._freeze_non_classifier()
|
||||
|
||||
def _freeze_non_classifier(self):
|
||||
"""Freeze all parameters except classifier."""
|
||||
print("Finetuning classifier only: freezing all other parameters.")
|
||||
for name, param in self.model.named_parameters():
|
||||
if not name.startswith("classifier"):
|
||||
param.requires_grad = False
|
||||
|
||||
def create_optimizer(self):
|
||||
"""Create optimizer with separate parameter groups."""
|
||||
if self.finetune_classifier:
|
||||
# Only classifier parameters
|
||||
classifier_params = [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if n.startswith("classifier") and p.requires_grad
|
||||
]
|
||||
if not classifier_params:
|
||||
raise ValueError("No trainable parameters in classifier.")
|
||||
|
||||
groups = [{
|
||||
"params": classifier_params,
|
||||
"lr": self.custom_lr,
|
||||
"weight_decay": self.custom_weight_decay,
|
||||
}]
|
||||
else:
|
||||
# Separate encoder and other parameters
|
||||
encoder_params = list(self.model.encoder.parameters())
|
||||
other_params = [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if "encoder" not in n and p.requires_grad
|
||||
]
|
||||
|
||||
groups = [
|
||||
{
|
||||
"params": encoder_params,
|
||||
"lr": self.encoder_lr,
|
||||
"weight_decay": self.custom_weight_decay
|
||||
},
|
||||
{
|
||||
"params": other_params,
|
||||
"lr": self.custom_lr,
|
||||
"weight_decay": self.custom_weight_decay
|
||||
},
|
||||
]
|
||||
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
||||
self.optimizer = optimizer_cls(groups, **optimizer_kwargs)
|
||||
|
||||
def compute_loss(self, model, inputs: PreprocessedBatch, return_outputs: bool = False, **kwargs):
|
||||
"""
|
||||
Compute loss on preprocessed batch.
|
||||
|
||||
Args:
|
||||
model: The model
|
||||
inputs: PreprocessedBatch from collator
|
||||
return_outputs: Whether to return outputs dict
|
||||
|
||||
Returns:
|
||||
Loss tensor, optionally with outputs dict
|
||||
"""
|
||||
# Forward pass - inputs is already PreprocessedBatch
|
||||
outputs = model(inputs, return_individual_losses=False)
|
||||
|
||||
# Handle empty batch
|
||||
if outputs["batch_size"] == 0:
|
||||
device = next(model.parameters()).device
|
||||
loss = torch.tensor(0.0, device=device, requires_grad=True)
|
||||
else:
|
||||
loss = outputs["total_loss"]
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Training Utilities
|
||||
# =============================================================================
|
||||
|
||||
def create_training_dataloader(
|
||||
dataset: ExtractorDataset,
|
||||
processor: SchemaTransformer,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 8,
|
||||
pin_memory: bool = True,
|
||||
shuffle: bool = True,
|
||||
prefetch_factor: int = 2,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Create an optimized DataLoader for training.
|
||||
|
||||
This function creates a DataLoader configured for maximum preprocessing
|
||||
efficiency using parallel workers.
|
||||
|
||||
Args:
|
||||
dataset: ExtractorDataset instance
|
||||
processor: SchemaTransformer for preprocessing
|
||||
batch_size: Batch size
|
||||
num_workers: Number of parallel workers for preprocessing
|
||||
pin_memory: Pin memory for faster GPU transfer
|
||||
shuffle: Shuffle data each epoch
|
||||
prefetch_factor: Batches to prefetch per worker
|
||||
|
||||
Returns:
|
||||
Configured DataLoader
|
||||
|
||||
Example:
|
||||
>>> loader = create_training_dataloader(
|
||||
... dataset=train_dataset,
|
||||
... processor=processor,
|
||||
... batch_size=32,
|
||||
... num_workers=8,
|
||||
... )
|
||||
>>> for batch in loader:
|
||||
... batch = batch.to(device)
|
||||
... loss = model(batch)["total_loss"]
|
||||
"""
|
||||
collator = ExtractorDataCollator(processor, is_training=True)
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
prefetch_factor=prefetch_factor if num_workers > 0 else None,
|
||||
collate_fn=collator,
|
||||
persistent_workers=num_workers > 0,
|
||||
)
|
||||
|
||||
|
||||
def create_inference_dataloader(
|
||||
texts: List[str],
|
||||
schemas: List[dict],
|
||||
processor: SchemaTransformer,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 4,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Create a DataLoader for inference.
|
||||
|
||||
Args:
|
||||
texts: List of input texts
|
||||
schemas: List of schemas (same length as texts or single schema)
|
||||
processor: SchemaTransformer for preprocessing
|
||||
batch_size: Batch size
|
||||
num_workers: Number of workers
|
||||
|
||||
Returns:
|
||||
DataLoader yielding PreprocessedBatch
|
||||
"""
|
||||
# Handle single schema for all texts
|
||||
if len(schemas) == 1:
|
||||
schemas = schemas * len(texts)
|
||||
|
||||
dataset = list(zip(texts, schemas))
|
||||
collator = ExtractorDataCollator(processor, is_training=False)
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collator,
|
||||
)
|
||||
1072 packages/GLiNER2/gliner2/processor.py Normal file
File diff suppressed because it is too large
0 packages/GLiNER2/gliner2/training/__init__.py Normal file
1277 packages/GLiNER2/gliner2/training/data.py Normal file
File diff suppressed because it is too large
836 packages/GLiNER2/gliner2/training/lora.py Normal file
@@ -0,0 +1,836 @@
"""
|
||||
Custom LoRA (Low-Rank Adaptation) Implementation for GLiNER2
|
||||
=============================================================
|
||||
|
||||
Parameter-efficient fine-tuning by injecting trainable low-rank matrices
|
||||
into frozen linear layers of the encoder.
|
||||
|
||||
Based on: "LoRA: Low-Rank Adaptation of Large Language Models"
|
||||
Paper: https://arxiv.org/abs/2106.09685
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import save_file, load_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LoRA Configuration
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class LoRAConfig:
|
||||
"""
|
||||
Configuration for LoRA parameter-efficient fine-tuning.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
enabled : bool
|
||||
Whether LoRA is enabled.
|
||||
r : int
|
||||
Rank of the low-rank decomposition (bottleneck dimension).
|
||||
Higher r = more parameters but better approximation.
|
||||
Typical values: 4, 8, 16, 32, 64.
|
||||
alpha : float
|
||||
Scaling factor for LoRA updates. Final scaling is alpha/r.
|
||||
Typical values: 8, 16, 32 (often 2*r).
|
||||
dropout : float
|
||||
Dropout probability applied to LoRA path.
|
||||
target_modules : List[str]
|
||||
Module names to apply LoRA to. Supported module groups:
|
||||
|
||||
- "encoder" - Applies LoRA to query, key, value, dense layers within encoder
|
||||
- "encoder.query" - Only query layers in encoder
|
||||
- "encoder.key" - Only key layers in encoder
|
||||
- "encoder.value" - Only value layers in encoder
|
||||
- "encoder.dense" - Only dense layers in encoder
|
||||
- "span_rep" - Applies LoRA to ALL linear layers within span_rep
|
||||
- "classifier" - Applies LoRA to ALL linear layers within classifier
|
||||
- "count_embed" - Applies LoRA to ALL linear layers within count_embed
|
||||
- "count_pred" - Applies LoRA to ALL linear layers within count_pred
|
||||
|
||||
Examples:
|
||||
- ["encoder"] - all encoder layers (query, key, value, dense)
|
||||
- ["encoder.query", "encoder.key", "encoder.value"] - only attention layers
|
||||
- ["encoder.dense"] - only dense (FFN) layers in encoder
|
||||
- ["encoder", "span_rep", "classifier"] - encoder + task heads
|
||||
- ["classifier"] - classifier fine-tuning only
|
||||
"""
|
||||
enabled: bool = False
|
||||
r: int = 8
|
||||
alpha: float = 16.0
|
||||
dropout: float = 0.0
|
||||
target_modules: List[str] = field(default_factory=lambda: ["encoder"])
|
||||
|
||||
def __post_init__(self):
|
||||
if self.r <= 0:
|
||||
raise ValueError(f"LoRA rank must be > 0, got {self.r}")
|
||||
if self.alpha <= 0:
|
||||
raise ValueError(f"LoRA alpha must be > 0, got {self.alpha}")
|
||||
if not 0 <= self.dropout < 1:
|
||||
raise ValueError(f"LoRA dropout must be in [0, 1), got {self.dropout}")
|
||||
if self.enabled and not self.target_modules:
|
||||
raise ValueError("target_modules cannot be empty when LoRA is enabled")


@dataclass
class LoRAAdapterConfig:
    """
    Configuration for a saved LoRA adapter.

    This is the config that gets saved with adapter-only checkpoints.
    """
    adapter_type: str = "lora"
    adapter_version: str = "1.0"
    lora_r: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    target_modules: List[str] = field(default_factory=list)
    created_at: str = ""

    def save(self, path: Union[str, Path]) -> None:
        """Save adapter config to JSON file."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        config_path = path / "adapter_config.json"

        # Set created_at if not set
        if not self.created_at:
            self.created_at = datetime.utcnow().isoformat() + "Z"

        with open(config_path, "w") as f:
            json.dump(asdict(self), f, indent=2)

        logger.info(f"Saved adapter config to {config_path}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'LoRAAdapterConfig':
        """Load adapter config from JSON file or directory."""
        path = Path(path)

        # If path is a directory, look for adapter_config.json
        if path.is_dir():
            config_path = path / "adapter_config.json"
        else:
            config_path = path

        if not config_path.exists():
            raise FileNotFoundError(f"Adapter config not found at {config_path}")

        with open(config_path) as f:
            config_dict = json.load(f)

        return cls(**config_dict)

    @classmethod
    def is_adapter_path(cls, path: Union[str, Path]) -> bool:
        """Check if path contains an adapter."""
        path = Path(path)

        # Check for adapter_config.json
        if path.is_dir():
            return (path / "adapter_config.json").exists()
        else:
            return path.name == "adapter_config.json" and path.exists()


# =============================================================================
# LoRA Layer
# =============================================================================

class LoRALayer(nn.Module):
    """
    LoRA-enhanced Linear layer.

    Computes: output = W*x + (B*A*x) * scaling
    Where:
    - W is the frozen original weight
    - A, B are trainable low-rank matrices
    - scaling = alpha / r

    Parameters
    ----------
    base_layer : nn.Linear
        Original linear layer (will be frozen).
    r : int
        Rank of low-rank decomposition.
    alpha : float
        LoRA scaling factor.
    dropout : float
        Dropout probability.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        r: int,
        alpha: float,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        in_features = base_layer.in_features
        out_features = base_layer.out_features

        # Store frozen base layer
        self.base_layer = base_layer
        for param in self.base_layer.parameters():
            param.requires_grad = False

        # Get device from base layer to ensure LoRA parameters are on the same device
        device = next(base_layer.parameters()).device

        # LoRA low-rank matrices
        # A: (r, in_features) - initialized with small random values
        # B: (out_features, r) - initialized to zero (no change at start)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features, device=device))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r, device=device))

        # Initialize A with Kaiming uniform (same as nn.Linear default)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        # B stays zero-initialized

        # Dropout
        self.lora_dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

        # Flag to track if weights are merged
        self.merged = False

    # Expose base layer attributes for compatibility
    @property
    def weight(self):
        """Expose weight from base layer for compatibility."""
        return self.base_layer.weight

    @property
    def bias(self):
        """Expose bias from base layer for compatibility."""
        return self.base_layer.bias

    @property
    def in_features(self):
        """Expose in_features from base layer."""
        return self.base_layer.in_features

    @property
    def out_features(self):
        """Expose out_features from base layer."""
        return self.base_layer.out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with LoRA.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (..., in_features).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (..., out_features).
        """
        # Base output from frozen weights
        base_output = self.base_layer(x)

        if self.merged:
            # Weights already merged, just use base layer
            return base_output

        # LoRA path: x -> dropout -> A -> B -> scale
        # Equivalent to: (x @ A.T) @ B.T * scaling
        lora_output = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T

        return base_output + lora_output * self.scaling

    def merge_weights(self):
        """Merge LoRA weights (B @ A) into base layer weights."""
        if self.merged:
            # Already merged, silently skip
            return

        with torch.no_grad():
            # Compute LoRA contribution: B @ A * scaling
            lora_weight = (self.lora_B @ self.lora_A) * self.scaling
            # Add to base weight
            self.base_layer.weight.data += lora_weight

        self.merged = True
        logger.debug(f"Merged LoRA weights (r={self.r}) into base layer")

    def unmerge_weights(self):
        """Separate LoRA weights from base layer (reverse of merge)."""
        if not self.merged:
            # Not merged, silently skip
            return

        with torch.no_grad():
            # Subtract LoRA contribution
            lora_weight = (self.lora_B @ self.lora_A) * self.scaling
            self.base_layer.weight.data -= lora_weight

        self.merged = False
        logger.debug(f"Unmerged LoRA weights (r={self.r}) from base layer")

    def extra_repr(self) -> str:
        return f"r={self.r}, alpha={self.alpha}, scaling={self.scaling:.4f}, merged={self.merged}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LoRA Application Functions
|
||||
# =============================================================================
|
||||
|
||||
# Module-specific patterns for LoRA application
ENCODER_PATTERNS = ["query", "key", "value", "dense"]
ALL_LINEAR_MODULES = ["span_rep", "classifier", "count_embed", "count_pred"]


def apply_lora_to_model(
    model: nn.Module,
    config: LoRAConfig,
) -> Tuple[nn.Module, Dict[str, LoRALayer]]:
    """
    Apply LoRA to linear layers based on module groups in target_modules.

    Module group behavior:
    - "encoder": Applies LoRA to query, key, value, dense layers within encoder
    - "encoder.query": Only query layers in encoder
    - "encoder.key": Only key layers in encoder
    - "encoder.value": Only value layers in encoder
    - "encoder.dense": Only dense layers in encoder
    - "span_rep", "classifier", "count_embed", "count_pred": Applies LoRA to ALL linear layers

    Parameters
    ----------
    model : nn.Module
        The model to apply LoRA to.
    config : LoRAConfig
        LoRA configuration.

    Returns
    -------
    model : nn.Module
        Modified model with LoRA layers.
    lora_layers : Dict[str, LoRALayer]
        Dictionary mapping layer names to LoRA layers.
    """
    if not config.enabled:
        logger.info("LoRA is disabled, skipping application")
        return model, {}

    lora_layers = {}

    def _should_apply_lora(local_name: str, full_path: str) -> bool:
        """
        Check if LoRA should be applied based on module groups.

        Args:
            local_name: Local module name (e.g., "query", "linear")
            full_path: Full path from model root (e.g., "encoder.layer.0.attention.self.query")

        Returns:
            True if LoRA should be applied to this layer
        """
        for target in config.target_modules:
            if target == "encoder":
                # For encoder, apply only to specific patterns
                if full_path.startswith("encoder."):
                    # Check if local name matches encoder patterns
                    if any(pattern in local_name for pattern in ENCODER_PATTERNS):
                        return True
            elif target.startswith("encoder."):
                # Specific encoder layer (e.g., "encoder.query", "encoder.dense")
                layer_name = target.split(".", 1)[1]  # Extract "query" from "encoder.query"
                if full_path.startswith("encoder.") and layer_name in local_name:
                    return True
            elif target in ALL_LINEAR_MODULES:
                # For these modules, apply to ALL linear layers within
                if full_path.startswith(f"{target}."):
                    return True
        return False

    # Recursively find and replace modules
    def _inject_lora_recursive(module: nn.Module, prefix: str = ""):
        for name, child in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name

            # Apply LoRA to matching Linear layers
            if isinstance(child, nn.Linear) and _should_apply_lora(name, full_name):
                # Replace with LoRA layer
                lora_layer = LoRALayer(
                    base_layer=child,
                    r=config.r,
                    alpha=config.alpha,
                    dropout=config.dropout,
                )
                setattr(module, name, lora_layer)
                lora_layers[full_name] = lora_layer

                logger.debug(
                    f"Applied LoRA to {full_name} "
                    f"(in={child.in_features}, out={child.out_features})"
                )
            else:
                # Recurse into child
                _inject_lora_recursive(child, full_name)

    _inject_lora_recursive(model)

    if not lora_layers:
        logger.warning(
            f"No LoRA layers were applied. Target modules {config.target_modules} "
            f"not found. Check your target_modules configuration."
        )
    else:
        logger.info(f"Applied LoRA to {len(lora_layers)} layers")

    return model, lora_layers

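To make the module-group semantics above concrete, here is a standalone sketch of the same matching rules, duplicated outside the model so it runs on plain strings (the module paths are made-up examples, not real GLiNER2 layer names):

```python
# Same constants as in the module above
ENCODER_PATTERNS = ["query", "key", "value", "dense"]
ALL_LINEAR_MODULES = ["span_rep", "classifier", "count_embed", "count_pred"]

def should_apply_lora(local_name, full_path, target_modules):
    """Standalone re-statement of the _should_apply_lora rules."""
    for target in target_modules:
        if target == "encoder":
            # "encoder" group: only attention/feed-forward projections
            if full_path.startswith("encoder.") and any(p in local_name for p in ENCODER_PATTERNS):
                return True
        elif target.startswith("encoder."):
            # "encoder.query" etc.: narrow to one projection type
            layer_name = target.split(".", 1)[1]
            if full_path.startswith("encoder.") and layer_name in local_name:
                return True
        elif target in ALL_LINEAR_MODULES:
            # Head modules: every linear layer inside matches
            if full_path.startswith(f"{target}."):
                return True
    return False

# "encoder" matches attention projections but not embeddings:
print(should_apply_lora("query", "encoder.layer.0.attention.self.query", ["encoder"]))   # True
print(should_apply_lora("word_embeddings", "encoder.embeddings.word_embeddings", ["encoder"]))  # False
# "encoder.query" narrows the group to query projections only:
print(should_apply_lora("dense", "encoder.layer.0.output.dense", ["encoder.query"]))     # False
# A head module like "classifier" matches any linear inside it:
print(should_apply_lora("linear", "classifier.linear", ["classifier"]))                  # True
```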
def get_lora_parameters(model: nn.Module) -> List[nn.Parameter]:
    """
    Extract all LoRA parameters (lora_A and lora_B) from model.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA layers.

    Returns
    -------
    List[nn.Parameter]
        List of LoRA parameters.
    """
    lora_params = []
    for module in model.modules():
        if isinstance(module, LoRALayer):
            lora_params.extend([module.lora_A, module.lora_B])
    return lora_params


def get_lora_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
    """
    Get state dict containing only LoRA parameters.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA layers.

    Returns
    -------
    Dict[str, torch.Tensor]
        State dict with LoRA parameters only.
    """
    lora_state = {}
    for name, module in model.named_modules():
        if isinstance(module, LoRALayer):
            lora_state[f"{name}.lora_A"] = module.lora_A.data
            lora_state[f"{name}.lora_B"] = module.lora_B.data
    return lora_state


def merge_lora_weights(model: nn.Module) -> int:
    """
    Merge all LoRA weights into base layers and remove LoRA structure.

    After calling this, the model will have standard Linear layers with
    merged weights. LoRA adapters are removed from the model.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA layers.

    Returns
    -------
    int
        Number of layers merged and removed.
    """
    count = 0
    already_merged = 0
    for module in model.modules():
        if isinstance(module, LoRALayer):
            if not module.merged:
                module.merge_weights()
                count += 1
            else:
                already_merged += 1

    if count > 0:
        logger.debug(f"Merged LoRA weights in {count} layers")
    if already_merged > 0:
        logger.debug(f"Skipped {already_merged} layers (already merged)")

    # Remove LoRA layers after merging
    if count > 0 or already_merged > 0:
        remove_lora_from_model(model)
        logger.info("Merged and removed LoRA layers from model")

    return count


def unmerge_lora_weights(model: nn.Module) -> int:
    """
    Unmerge all LoRA weights from their base layers.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA layers.

    Returns
    -------
    int
        Number of layers unmerged.
    """
    count = 0
    not_merged = 0
    for module in model.modules():
        if isinstance(module, LoRALayer):
            if module.merged:
                module.unmerge_weights()
                count += 1
            else:
                not_merged += 1

    if count > 0:
        logger.debug(f"Unmerged LoRA weights in {count} layers")
    if not_merged > 0:
        logger.debug(f"Skipped {not_merged} layers (not merged)")
    return count


def count_lora_parameters(model: nn.Module) -> Tuple[int, int, float]:
    """
    Count LoRA parameters vs total parameters.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA layers.

    Returns
    -------
    lora_params : int
        Number of trainable LoRA parameters.
    total_params : int
        Total number of model parameters.
    percentage : float
        Percentage of trainable parameters.
    """
    lora_params = sum(p.numel() for p in get_lora_parameters(model))
    total_params = sum(p.numel() for p in model.parameters())
    percentage = (lora_params / total_params * 100) if total_params > 0 else 0.0

    return lora_params, total_params, percentage


def print_lora_info(model: nn.Module, config: LoRAConfig):
    """
    Print detailed LoRA configuration and parameter statistics.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA layers.
    config : LoRAConfig
        LoRA configuration.
    """
    lora_params, total_params, percentage = count_lora_parameters(model)

    # Count LoRA layers
    num_lora_layers = sum(1 for m in model.modules() if isinstance(m, LoRALayer))

    print("=" * 70)
    print("🔧 LoRA Configuration")
    print("=" * 70)
    print(f"Enabled          : {config.enabled}")
    print(f"Rank (r)         : {config.r}")
    print(f"Alpha            : {config.alpha}")
    print(f"Scaling (α/r)    : {config.alpha / config.r:.4f}")
    print(f"Dropout          : {config.dropout}")
    print(f"Target modules   : {', '.join(config.target_modules)}")
    print(f"LoRA layers      : {num_lora_layers}")
    print("-" * 70)
    print(f"Trainable params : {lora_params:,} / {total_params:,} ({percentage:.2f}%)")
    print(f"Memory savings   : ~{100 - percentage:.1f}% fewer gradients")
    print("=" * 70)


def remove_lora_from_model(model: nn.Module) -> nn.Module:
    """
    Remove LoRA layers and restore original Linear layers.
    Useful for inference with merged weights.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA layers.

    Returns
    -------
    nn.Module
        Model with LoRA layers replaced by standard Linear layers.
    """
    def _remove_lora_recursive(module: nn.Module):
        for name, child in module.named_children():
            if isinstance(child, LoRALayer):
                # Ensure weights are merged
                if not child.merged:
                    child.merge_weights()
                # Replace LoRALayer with its base layer
                setattr(module, name, child.base_layer)
                logger.debug(f"Removed LoRA from {name}, restored base layer")
            else:
                _remove_lora_recursive(child)

    _remove_lora_recursive(model)
    logger.info("Removed all LoRA layers from model")
    return model


# =============================================================================
# Adapter Management Functions
# =============================================================================

def save_lora_adapter(
    model: nn.Module,
    save_path: Union[str, Path],
) -> None:
    """
    Save only LoRA adapter weights and config.

    Args:
        model: Model with LoRA layers (must NOT be merged)
        save_path: Directory to save adapter

    Saves:
        - adapter_config.json
        - adapter_weights.safetensors
    """
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    # Collect LoRA weights and config
    lora_state = {}
    lora_config = None

    for name, module in model.named_modules():
        if isinstance(module, LoRALayer):
            if module.merged:
                raise ValueError(
                    "Cannot save adapter with merged weights. "
                    "Call unmerge_lora_weights() first."
                )
            # Save LoRA matrices with full path from model root
            lora_state[f"{name}.lora_A"] = module.lora_A.data
            lora_state[f"{name}.lora_B"] = module.lora_B.data

            # Extract config from first LoRA layer
            if lora_config is None:
                lora_config = {
                    "lora_r": module.r,
                    "lora_alpha": module.alpha,
                    "lora_dropout": module.lora_dropout.p if hasattr(module.lora_dropout, 'p') else 0.0,
                }

    if not lora_state:
        raise ValueError("No LoRA layers found in model")

    # Save weights
    weights_path = save_path / "adapter_weights.safetensors"
    save_file(lora_state, str(weights_path))
    logger.info(f"Saved {len(lora_state)} LoRA tensors to {weights_path}")

    # Determine target modules from layer names
    # Extract top-level module names (encoder, span_rep, classifier, etc.)
    target_modules = set()
    for key in lora_state.keys():
        # Extract first level module from full path
        # e.g., "encoder.layer.0.attention.self.query.lora_A" -> "encoder"
        # e.g., "span_rep.project_start.0.lora_A" -> "span_rep"
        parts = key.split(".")
        if len(parts) > 0:
            # Get the first level module name
            module_name = parts[0]
            target_modules.add(module_name)

    # Create and save adapter config
    adapter_config = LoRAAdapterConfig(
        adapter_type="lora",
        adapter_version="1.0",
        lora_r=lora_config["lora_r"],
        lora_alpha=lora_config["lora_alpha"],
        lora_dropout=lora_config["lora_dropout"],
        target_modules=sorted(list(target_modules)),
        created_at=datetime.utcnow().isoformat() + "Z"
    )
    adapter_config.save(save_path)

    logger.info(f"Saved LoRA adapter to {save_path}")


def load_lora_adapter(
    model: nn.Module,
    adapter_path: Union[str, Path],
    auto_unload: bool = True,
) -> Dict[str, LoRALayer]:
    """
    Load LoRA adapter onto model.

    Args:
        model: Base model (should not have LoRA applied)
        adapter_path: Path to adapter directory
        auto_unload: If True, unload existing adapter first

    Returns:
        Dict of LoRA layers that were applied
    """
    adapter_path = Path(adapter_path)

    # Load adapter config
    adapter_config = LoRAAdapterConfig.load(adapter_path)

    # Unload existing adapter if requested
    if auto_unload and has_lora_adapter(model):
        logger.info("Unloading existing adapter before loading new one")
        unload_lora_adapter(model)

    # Load adapter weights
    weights_path = adapter_path / "adapter_weights.safetensors"
    if not weights_path.exists():
        raise FileNotFoundError(f"Adapter weights not found at {weights_path}")

    lora_state = load_file(str(weights_path))
    logger.info(f"Loaded {len(lora_state)} LoRA tensors from {weights_path}")

    # Apply LoRA to matching layers
    lora_config = LoRAConfig(
        enabled=True,
        r=adapter_config.lora_r,
        alpha=adapter_config.lora_alpha,
        dropout=adapter_config.lora_dropout,
        target_modules=adapter_config.target_modules,
    )

    model, lora_layers = apply_lora_to_model(model, lora_config)

    # Load saved weights into LoRA layers
    for name, module in model.named_modules():
        if isinstance(module, LoRALayer):
            lora_a_key = f"{name}.lora_A"
            lora_b_key = f"{name}.lora_B"

            if lora_a_key in lora_state and lora_b_key in lora_state:
                # Move loaded tensors to the same device as the module
                device = next(module.parameters()).device
                module.lora_A.data = lora_state[lora_a_key].to(device)
                module.lora_B.data = lora_state[lora_b_key].to(device)
                logger.debug(f"Loaded weights for {name}")
            else:
                logger.warning(f"No saved weights found for {name}")

    logger.info(f"Loaded LoRA adapter from {adapter_path}")
    return lora_layers


def unload_lora_adapter(model: nn.Module) -> int:
    """
    Remove all LoRA layers, restoring original Linear layers.

    Unlike remove_lora_from_model, this does NOT merge weights.
    Just removes LoRA layers entirely.

    Returns:
        Number of layers unloaded
    """
    count = 0

    def _get_parent_module(model: nn.Module, full_name: str) -> Tuple[nn.Module, str]:
        """Get parent module and child name from full module path."""
        parts = full_name.split('.')
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        return parent, parts[-1]

    # Collect all LoRA layers first (to avoid modifying dict during iteration)
    lora_layers = []
    for name, module in model.named_modules():
        if isinstance(module, LoRALayer):
            lora_layers.append((name, module))

    # Remove LoRA layers
    for name, lora_layer in lora_layers:
        parent, child_name = _get_parent_module(model, name)
        # Replace with original base_layer (no merge)
        setattr(parent, child_name, lora_layer.base_layer)
        count += 1
        logger.debug(f"Unloaded LoRA from {name}")

    if count > 0:
        logger.info(f"Unloaded {count} LoRA layers")

    return count


def has_lora_adapter(model: nn.Module) -> bool:
    """Check if model has LoRA layers applied."""
    for module in model.modules():
        if isinstance(module, LoRALayer):
            return True
    return False


def get_adapter_config(model: nn.Module) -> Optional[LoRAAdapterConfig]:
    """
    Get config of currently loaded adapter, if any.

    Note: This reconstructs config from LoRA layers.
    The actual adapter config is stored in model._adapter_config
    when loaded via model.load_adapter().
    """
    if not has_lora_adapter(model):
        return None

    # Extract config from first LoRA layer
    for module in model.modules():
        if isinstance(module, LoRALayer):
            target_modules = set()
            # Collect all target module groups (top-level modules)
            for name, m in model.named_modules():
                if isinstance(m, LoRALayer):
                    # Extract first level module name
                    parts = name.split(".")
                    if parts:
                        target_modules.add(parts[0])

            return LoRAAdapterConfig(
                adapter_type="lora",
                adapter_version="1.0",
                lora_r=module.r,
                lora_alpha=module.alpha,
                lora_dropout=module.lora_dropout.p if hasattr(module.lora_dropout, 'p') else 0.0,
                target_modules=sorted(list(target_modules)),
                created_at=""
            )

    return None

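The merge/unmerge logic above reduces to the identity W' = W + (α/r)·(B @ A). A minimal NumPy sketch (independent of the torch classes in this file, with made-up shapes) checks that the merged weight reproduces the base-plus-LoRA forward pass:

```python
import numpy as np

rng = np.random.default_rng(0)
in_f, out_f, r, alpha = 8, 6, 2, 4
scaling = alpha / r                       # same α/r scaling as LoRALayer

W = rng.normal(size=(out_f, in_f))        # frozen base weight
A = rng.normal(size=(r, in_f))            # lora_A
B = rng.normal(size=(out_f, r))           # lora_B (nonzero, as if trained)
x = rng.normal(size=(3, in_f))            # a small input batch

# Unmerged forward: base output plus the scaled low-rank path (forward())
unmerged = x @ W.T + (x @ A.T) @ B.T * scaling
# Merged forward: fold (B @ A) * scaling into the weight once (merge_weights())
merged = x @ (W + (B @ A) * scaling).T

print(np.allclose(unmerged, merged))  # True
```

This is why merge_weights()/unmerge_weights() can add and subtract the same `lora_weight` tensor without changing model outputs.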
1409
packages/GLiNER2/gliner2/training/trainer.py
Normal file
File diff suppressed because it is too large
Load Diff
25
packages/GLiNER2/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
[build-system]
requires = ["setuptools>=61.0.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
include = ["gliner2", "gliner2.*"]

[tool.setuptools.dynamic]
version = {attr = "gliner2.__version__"}

[project]
name = "gliner2"
readme = "README.md"
requires-python = ">=3.8"

maintainers = [
    {name = "Urchade Zaratiana"},
]

dependencies = [
    "gliner",
    "pydantic>=2.0.0",
]

dynamic = ["version"]
663
packages/GLiNER2/tutorial/1-classification.md
Normal file
@@ -0,0 +1,663 @@
# GLiNER2 Classification Tutorial

This tutorial covers all the ways to perform text classification with GLiNER2, from simple single-label classification to complex multi-label tasks with custom configurations.

## Table of Contents
- [Setup](#setup)
- [Single-Label Classification](#single-label-classification)
- [Multi-Label Classification](#multi-label-classification)
- [Classification with Descriptions](#classification-with-descriptions)
- [Using the Quick API](#using-the-quick-api)
- [Multiple Classification Tasks](#multiple-classification-tasks)
- [Advanced Configurations](#advanced-configurations)
- [Best Practices](#best-practices)

## Setup

```python
from gliner2 import GLiNER2

# Load the pre-trained model
extractor = GLiNER2.from_pretrained("your-model-name")
```

## Single-Label Classification

The simplest form - classify text into one of several categories.

### Basic Example

```python
# Define the schema
schema = extractor.create_schema().classification(
    "sentiment",
    ["positive", "negative", "neutral"]
)

# Extract
text = "This product exceeded my expectations! Absolutely love it."
results = extractor.extract(text, schema)
print(results)
# Expected output: {'sentiment': 'positive'}
```

### With Confidence Scores

```python
# Same schema as above
schema = extractor.create_schema().classification(
    "sentiment",
    ["positive", "negative", "neutral"]
)

text = "The service was okay, nothing special but not bad either."
results = extractor.extract(text, schema, include_confidence=True)
print(results)
# Expected output: {'sentiment': {'label': 'neutral', 'confidence': 0.82}}
```

## Multi-Label Classification

When text can belong to multiple categories simultaneously.

```python
# Multi-label classification
schema = extractor.create_schema().classification(
    "topics",
    ["technology", "business", "health", "politics", "sports"],
    multi_label=True,
    cls_threshold=0.3  # Lower threshold for multi-label
)

text = "Apple announced new health monitoring features in their latest smartwatch, boosting their stock price."
results = extractor.extract(text, schema)
print(results)
# Expected output: {'topics': ['technology', 'business', 'health']}

# With confidence scores
results = extractor.extract(text, schema, include_confidence=True)
print(results)
# Expected output: {'topics': [
#     {'label': 'technology', 'confidence': 0.92},
#     {'label': 'business', 'confidence': 0.78},
#     {'label': 'health', 'confidence': 0.65}
# ]}
```

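Lowering `cls_threshold` admits more labels. A minimal sketch of that selection step, using made-up scores purely for illustration (the real scores come from the model, not from this dict):

```python
# Hypothetical per-label confidence scores, like the ones shown above
scores = {"technology": 0.92, "business": 0.78, "health": 0.65,
          "politics": 0.12, "sports": 0.04}

def multi_label_select(scores, cls_threshold):
    # Keep every label at or above the threshold, highest score first
    return [label for label, s in sorted(scores.items(), key=lambda kv: -kv[1])
            if s >= cls_threshold]

print(multi_label_select(scores, 0.3))  # ['technology', 'business', 'health']
print(multi_label_select(scores, 0.7))  # ['technology', 'business']
```

So the same text yields a shorter or longer label list depending only on where the threshold sits.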
## Classification with Descriptions

Adding descriptions significantly improves accuracy by providing context.

```python
# With label descriptions
schema = extractor.create_schema().classification(
    "document_type",
    {
        "invoice": "A bill for goods or services with payment details",
        "receipt": "Proof of payment for a completed transaction",
        "contract": "Legal agreement between parties with terms and conditions",
        "proposal": "Document outlining suggested plans or services with pricing"
    }
)

text = "Please find attached the itemized bill for consulting services rendered in Q3 2024. Payment is due within 30 days."
results = extractor.extract(text, schema)
print(results)
# Expected output: {'document_type': 'invoice'}

# Another example
text2 = "Thank you for your payment of $500. This confirms your transaction was completed on March 1st, 2024."
results2 = extractor.extract(text2, schema)
print(results2)
# Expected output: {'document_type': 'receipt'}
```

## Using the Quick API

For simple classification tasks without building a schema.

### Single Task

```python
text = "The new AI model shows remarkable performance improvements."
results = extractor.classify_text(
    text,
    {"sentiment": ["positive", "negative", "neutral"]}
)
print(results)
# Expected output: {'sentiment': 'positive'}

# Another example
text2 = "The software keeps crashing and customer support is unresponsive."
results2 = extractor.classify_text(
    text2,
    {"sentiment": ["positive", "negative", "neutral"]}
)
print(results2)
# Expected output: {'sentiment': 'negative'}
```

### Multiple Tasks

```python
text = "Breaking: Tech giant announces major layoffs amid market downturn"
results = extractor.classify_text(
    text,
    {
        "sentiment": ["positive", "negative", "neutral"],
        "urgency": ["high", "medium", "low"],
        "category": {
            "labels": ["tech", "finance", "politics", "sports"],
            "multi_label": False
        }
    }
)
print(results)
# Expected output: {
#     'sentiment': 'negative',
#     'urgency': 'high',
#     'category': 'tech'
# }
```

### Multi-Label with Config

```python
text = "The smartphone features an amazing camera but disappointing battery life and overheats frequently."
results = extractor.classify_text(
    text,
    {
        "product_aspects": {
            "labels": ["camera", "battery", "display", "performance", "design", "heating"],
            "multi_label": True,
            "cls_threshold": 0.4
        }
    }
)
print(results)
# Expected output: {'product_aspects': ['camera', 'battery', 'heating']}

# Another example
text2 = "Beautiful design with vibrant display, though the camera could be better."
results2 = extractor.classify_text(
    text2,
    {
        "product_aspects": {
            "labels": ["camera", "battery", "display", "performance", "design", "heating"],
            "multi_label": True,
            "cls_threshold": 0.4
        }
    }
)
print(results2)
# Expected output: {'product_aspects': ['design', 'display', 'camera']}
```

## Multiple Classification Tasks

You can include multiple classification tasks in a single schema for comprehensive text analysis.

### Basic Multiple Classifications

```python
# Multiple independent classifications
schema = (extractor.create_schema()
    .classification("sentiment", ["positive", "negative", "neutral"])
    .classification("language", ["english", "spanish", "french", "german", "other"])
    .classification("formality", ["formal", "informal", "semi-formal"])
    .classification("intent", ["question", "statement", "request", "complaint"])
)

text = "Could you please help me with my order? The service has been disappointing."
results = extractor.extract(text, schema)
print(results)
# Expected output: {
#     'sentiment': 'negative',
#     'language': 'english',
#     'formality': 'formal',
#     'intent': 'question'
# }

# Another example
text2 = "Hey! Just wanted to say your product rocks! 🎉"
results2 = extractor.extract(text2, schema)
print(results2)
# Expected output: {
#     'sentiment': 'positive',
#     'language': 'english',
#     'formality': 'informal',
#     'intent': 'statement'
# }
```

### Mixed Single and Multi-Label Classifications

```python
# Combine different classification types
schema = (extractor.create_schema()
    # Single-label classifications
    .classification("primary_topic", ["tech", "business", "health", "sports", "politics"])
    .classification("urgency", ["immediate", "soon", "later", "not_urgent"])

    # Multi-label classifications
    .classification("emotions",
        ["happy", "sad", "angry", "surprised", "fearful", "disgusted"],
        multi_label=True,
        cls_threshold=0.4
    )
    .classification("content_flags",
        ["inappropriate", "spam", "promotional", "personal_info", "financial_info"],
        multi_label=True,
        cls_threshold=0.3
    )
)

text = "URGENT: I'm thrilled to announce our new product! But concerned about competitor reactions. Please keep confidential."
results = extractor.extract(text, schema)
print(results)
# Expected output: {
#     'primary_topic': 'business',
#     'urgency': 'immediate',
#     'emotions': ['happy', 'fearful'],
#     'content_flags': ['promotional', 'personal_info']
# }

# Another example
text2 = "Just saw the game - absolutely devastated by the loss. Can't believe the referee's terrible decision!"
results2 = extractor.extract(text2, schema)
print(results2)
# Expected output: {
#     'primary_topic': 'sports',
#     'urgency': 'not_urgent',
#     'emotions': ['sad', 'angry'],
#     'content_flags': []
# }
```

### Domain-Specific Multiple Classifications

```python
# Customer support ticket classification
support_schema = (extractor.create_schema()
    .classification("ticket_type",
        ["technical_issue", "billing", "feature_request", "bug_report", "other"])
    .classification("priority",
        ["critical", "high", "medium", "low"],
        cls_threshold=0.7
    )
    .classification("product_area",
        {
            "authentication": "Login, passwords, security",
            "payment": "Payment processing, subscriptions",
            "ui": "User interface, design issues",
            "performance": "Speed, loading, responsiveness",
            "data": "Data loss, corruption, sync issues"
        },
        multi_label=True,
        cls_threshold=0.5
    )
    .classification("customer_sentiment",
        ["very_satisfied", "satisfied", "neutral", "frustrated", "very_frustrated"],
        cls_threshold=0.6
    )
    .classification("requires_action",
        ["immediate_response", "investigation_needed", "waiting_customer", "resolved"],
        multi_label=True
    )
)

ticket_text = """
Subject: Cannot login - Urgent!

I've been trying to login for the past hour but keep getting error messages.
This is critical as I need to process payments for my customers today.
The page just keeps spinning and then times out. I'm extremely frustrated
as this is costing me business. Please fix this immediately!
"""

results = extractor.extract(ticket_text, support_schema)
print(results)
# Expected output: {
#     'ticket_type': 'technical_issue',
#     'priority': 'critical',
#     'product_area': ['authentication', 'payment', 'performance'],
#     'customer_sentiment': 'very_frustrated',
#     'requires_action': ['immediate_response', 'investigation_needed']
# }

# Another support ticket example
ticket_text2 = """
Hi team,

Thanks for the great product! I was wondering if you could add a dark mode feature?
It would really help with eye strain during late night work sessions.

Best regards,
Happy Customer
"""

results2 = extractor.extract(ticket_text2, support_schema)
print(results2)
# Expected output: {
#     'ticket_type': 'feature_request',
#     'priority': 'low',
#     'product_area': ['ui'],
#     'customer_sentiment': 'satisfied',
#     'requires_action': ['waiting_customer']
# }
```

### Sequential Classification with Dependencies
|
||||
|
||||
```python
|
||||
# Email routing and handling classification
|
||||
email_schema = (extractor.create_schema()
|
||||
# Primary classification
|
||||
.classification("email_category",
|
||||
["sales", "support", "hr", "legal", "general"],
|
||||
cls_threshold=0.6
|
||||
)
|
||||
|
||||
# Secondary classifications based on context
|
||||
.classification("sales_stage",
|
||||
["lead", "qualified", "proposal", "negotiation", "closed"],
|
||||
cls_threshold=0.5
|
||||
)
|
||||
.classification("support_type",
|
||||
["pre_sales", "technical", "account", "billing"],
|
||||
cls_threshold=0.5
|
||||
)
|
||||
|
||||
# Action classifications
|
||||
.classification("required_action",
|
||||
["reply_needed", "forward_to_team", "schedule_meeting", "no_action"],
|
||||
multi_label=True,
|
||||
cls_threshold=0.4
|
||||
)
|
||||
.classification("response_timeframe",
|
||||
["within_1_hour", "within_24_hours", "within_week", "non_urgent"],
|
||||
cls_threshold=0.6
|
||||
)
|
||||
)
|
||||
|
||||
email = """
|
||||
Hi Sales Team,
|
||||
|
||||
I'm interested in your enterprise solution. We're currently evaluating vendors
|
||||
for our upcoming project. Could we schedule a demo next week? We need to make
|
||||
a decision by month end.
|
||||
|
||||
Best regards,
|
||||
John from TechCorp
|
||||
"""
|
||||
|
||||
results = extractor.extract(email, email_schema)
|
||||
print(results)
|
||||
# Expected output: {
|
||||
# 'email_category': 'sales',
|
||||
# 'sales_stage': 'qualified',
|
||||
# 'support_type': 'pre_sales',
|
||||
# 'required_action': ['reply_needed', 'schedule_meeting'],
|
||||
# 'response_timeframe': 'within_24_hours'
|
||||
# }
|
||||
|
||||
# HR email example
|
||||
email2 = """
|
||||
Dear HR Department,
|
||||
|
||||
I need to update my tax withholding information. Could someone please send me
|
||||
the necessary forms? This is somewhat urgent as I need this changed before the
|
||||
next payroll cycle.
|
||||
|
||||
Thank you,
|
||||
Sarah
|
||||
"""
|
||||
|
||||
results2 = extractor.extract(email2, email_schema)
|
||||
print(results2)
|
||||
# Expected output: {
|
||||
# 'email_category': 'hr',
|
||||
# 'sales_stage': 'lead', # May have noise in non-sales emails
|
||||
# 'support_type': 'account',
|
||||
# 'required_action': ['reply_needed'],
|
||||
# 'response_timeframe': 'within_24_hours'
|
||||
# }
|
||||
```

### Complex Analysis with Multiple Classifications

```python
# Content moderation and analysis
content_schema = (extractor.create_schema()
    # Content classifications
    .classification("content_type",
        ["article", "comment", "review", "social_post", "message"])
    .classification("primary_language",
        ["english", "spanish", "french", "other"])

    # Quality assessments
    .classification("quality_score",
        ["excellent", "good", "average", "poor", "spam"],
        cls_threshold=0.7
    )
    .classification("originality",
        ["original", "derivative", "duplicate", "plagiarized"],
        cls_threshold=0.8
    )

    # Safety and compliance
    .classification("safety_flags",
        {
            "hate_speech": "Contains discriminatory or hateful content",
            "violence": "Contains violent or threatening content",
            "adult": "Contains adult or explicit content",
            "misinformation": "Contains potentially false information",
            "personal_info": "Contains personal identifying information"
        },
        multi_label=True,
        cls_threshold=0.3
    )

    # Engagement predictions
    .classification("engagement_potential",
        ["viral", "high", "medium", "low"],
        cls_threshold=0.6
    )
    .classification("audience_fit",
        ["general", "professional", "academic", "youth", "senior"],
        multi_label=True,
        cls_threshold=0.5
    )
)

content_text = """
Just discovered this amazing productivity hack that doubled my output!
Here's what I do: I wake up at 5 AM, meditate for 20 minutes, then work
in 90-minute focused blocks. The results have been incredible. My email
is john.doe@example.com if you want more tips!
"""

results = extractor.extract(content_text, content_schema)
print(results)
# Expected output: {
#     'content_type': 'social_post',
#     'primary_language': 'english',
#     'quality_score': 'good',
#     'originality': 'original',
#     'safety_flags': ['personal_info'],
#     'engagement_potential': 'high',
#     'audience_fit': ['general', 'professional']
# }

# Review example
review_text = """
Worst product ever!!! Total scam! Don't buy this garbage. The company should
be shut down for selling this junk. I'm going to report them to authorities.
"""

results2 = extractor.extract(review_text, content_schema)
print(results2)
# Expected output: {
#     'content_type': 'review',
#     'primary_language': 'english',
#     'quality_score': 'poor',
#     'originality': 'original',
#     'safety_flags': ['violence'],  # Due to aggressive language
#     'engagement_potential': 'low',
#     'audience_fit': ['general']
# }
```

## Advanced Configurations

### Custom Thresholds

```python
# High-precision classification
schema = extractor.create_schema().classification(
    "is_spam",
    ["spam", "not_spam"],
    cls_threshold=0.9  # Very high confidence required
)

text = "Congratulations! You've won $1,000,000! Click here to claim your prize now!"
results = extractor.extract(text, schema)
print(results)
# Expected output: {'is_spam': 'spam'}

# Different thresholds for different tasks
schema = (extractor.create_schema()
    .classification("priority", ["urgent", "high", "normal", "low"], cls_threshold=0.8)
    .classification("department", ["sales", "support", "billing", "other"], cls_threshold=0.5)
)

text = "URGENT: Customer threatening to cancel $50k contract due to billing error"
results = extractor.extract(text, schema)
print(results)
# Expected output: {
#     'priority': 'urgent',
#     'department': 'billing'
# }
```

### Custom Activation Functions

```python
# Force a specific activation
schema = extractor.create_schema().classification(
    "category",
    ["A", "B", "C", "D"],
    class_act="softmax"  # Options: "sigmoid", "softmax", "auto"
)

text = "This clearly belongs to category B based on the criteria."
results = extractor.extract(text, schema)
print(results)
# Expected output: {'category': 'B'}
```
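
The two activations behave quite differently: softmax normalizes scores across all labels so they compete (one clear winner, scores sum to 1), while sigmoid scores each label independently in (0, 1), which suits multi-label setups. A small, library-independent sketch of the difference (illustrative logits only, not GLiNER2 internals):

```python
import math

def softmax(logits):
    """Normalize logits into probabilities that sum to 1 (labels compete)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(logits):
    """Score each label independently in (0, 1) (labels don't compete)."""
    return [1 / (1 + math.exp(-x)) for x in logits]

logits = [2.0, 0.5, 0.5, -1.0]  # raw scores for labels A, B, C, D

print([round(p, 3) for p in softmax(logits)])  # one dominant label, probs sum to 1
print([round(p, 3) for p in sigmoid(logits)])  # several labels can score high at once
```

This is why "auto" setups typically pick softmax for single-label tasks and sigmoid when `multi_label=True`.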

### Complex Multi-Label Example

```python
# Email classification system
schema = extractor.create_schema().classification(
    "email_tags",
    {
        "action_required": "Email requires recipient to take action",
        "meeting_request": "Email contains meeting invitation or scheduling",
        "project_update": "Email contains project status or updates",
        "urgent": "Email marked as urgent or time-sensitive",
        "question": "Email contains questions requiring answers",
        "fyi": "Informational email requiring no action"
    },
    multi_label=True,
    cls_threshold=0.35
)

email_text = """
Hi team,

Quick update on Project Alpha: We're ahead of schedule!

However, I need your input on the design mockups by EOD tomorrow.
Can we schedule a 30-min call this week to discuss?

This is quite urgent as the client is waiting.

Best,
Sarah
"""

results = extractor.extract(email_text, schema)
print(results)
# Expected output: {
#     'email_tags': ['action_required', 'meeting_request', 'project_update', 'urgent', 'question']
# }

# FYI email example
email_text2 = """
Team,

Just wanted to let everyone know that I'll be out of office next Monday for a
doctor's appointment. I'll be back Tuesday morning.

Thanks,
Mark
"""

results2 = extractor.extract(email_text2, schema)
print(results2)
# Expected output: {
#     'email_tags': ['fyi']
# }
```

## Best Practices

1. **Use Descriptions**: Always provide label descriptions when possible
```python
# Good - with descriptions
schema = extractor.create_schema().classification(
    "intent",
    {
        "purchase": "User wants to buy a product",
        "return": "User wants to return a product",
        "inquiry": "User asking for information"
    }
)

# Less effective - no context
schema = extractor.create_schema().classification(
    "intent",
    ["purchase", "return", "inquiry"]
)
```

2. **Adjust Thresholds**: Use lower thresholds for multi-label tasks (0.3-0.5) and higher thresholds for single-label tasks (0.5-0.7)
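
To see why the two modes want different thresholds, here is a plain-Python sketch of how a threshold interacts with each mode (illustrative scores only; this mirrors the selection behavior described in this tutorial, not GLiNER2's actual implementation):

```python
def apply_threshold(scores, threshold, multi_label):
    """Select labels from a {label: score} dict.

    multi_label=True keeps every label scoring at or above the threshold;
    multi_label=False keeps only the top label, and only if it clears it.
    """
    if multi_label:
        return [label for label, s in scores.items() if s >= threshold]
    best = max(scores, key=scores.get)
    return [best] if scores[best] >= threshold else []

scores = {"waterproof": 0.45, "wireless": 0.80, "rechargeable": 0.38, "portable": 0.10}

# A low multi-label threshold keeps every plausible tag...
print(apply_threshold(scores, 0.35, multi_label=True))
# -> ['waterproof', 'wireless', 'rechargeable']

# ...while a high single-label threshold keeps only a confident winner.
print(apply_threshold(scores, 0.7, multi_label=False))
# -> ['wireless']
```

With a high threshold in multi-label mode, borderline tags like `waterproof` (0.45) would be dropped entirely, which is usually not what you want.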

3. **Multi-Label Strategy**: Use multi-label when categories aren't mutually exclusive
```python
# Good use of multi-label
schema = extractor.create_schema().classification(
    "product_features",
    ["waterproof", "wireless", "rechargeable", "portable"],
    multi_label=True
)

# Should be single-label
schema = extractor.create_schema().classification(
    "size",
    ["small", "medium", "large"],
    multi_label=False  # Sizes are mutually exclusive
)
```

4. **Test with Real Examples**: Always test with actual text samples from your domain

## Common Use Cases

- **Sentiment Analysis**: Customer feedback, reviews, social media
- **Intent Classification**: Chatbots, customer service routing
- **Document Classification**: Email filtering, document management
- **Content Moderation**: Toxic content, spam detection
- **Topic Classification**: News categorization, content tagging
973
packages/GLiNER2/tutorial/10-lora_adapters.md
Normal file

# Tutorial 10: LoRA Adapters - Multi-Domain Inference

## Table of Contents
1. [Introduction](#introduction)
2. [Why Use LoRA Adapters?](#why-use-lora-adapters)
3. [Training Your First Adapter](#training-your-first-adapter)
4. [Training Multiple Domain Adapters](#training-multiple-domain-adapters)
5. [Loading and Swapping Adapters](#loading-and-swapping-adapters)
6. [Real-World Use Cases](#real-world-use-cases)
7. [Best Practices](#best-practices)
8. [Troubleshooting](#troubleshooting)
## Introduction

LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that allows you to train specialized adapters for different domains without modifying the base model. This enables:

- **Fast domain switching**: Swap between domains in milliseconds
- **Minimal storage**: Adapters are ~2-10 MB vs ~100-500 MB for full models
- **Domain specialization**: Train separate adapters for legal, medical, financial, etc.
- **Easy deployment**: Keep one base model + multiple lightweight adapters
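
The core idea fits in a few lines of plain Python (a conceptual sketch only, not GLiNER2's internals): instead of updating a frozen d×k weight matrix W, LoRA learns two small matrices A (r×k) and B (d×r) and computes y = Wx + (α/r)·B(Ax), so only A and B need training and storage.

```python
def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def lora_forward(W, A, B, x, alpha, r):
    """y = Wx + (alpha/r) * B(Ax): frozen base weight plus a rank-r update."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))  # low-rank path through an r-dim bottleneck
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]

# Toy 2x2 example with rank r=1
W = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight (identity here)
A = [[1.0, 1.0]]               # A is r x k = 1 x 2
B = [[0.5], [0.5]]             # B is d x r = 2 x 1
y = lora_forward(W, A, B, x=[2.0, 4.0], alpha=2.0, r=1)
print(y)  # -> [8.0, 10.0]: base output [2.0, 4.0] plus the scaled rank-1 update
```

Swapping adapters just means swapping the small A/B pairs while W stays untouched, which is what makes the domain switches below so cheap.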

## Why Use LoRA Adapters?

### Memory Efficiency

```
Full Model Fine-tuning:
- Legal model: 450 MB
- Medical model: 450 MB
- Financial model: 450 MB
Total: 1.35 GB

LoRA Adapters:
- Base model: 450 MB
- Legal adapter: 5 MB
- Medical adapter: 5 MB
- Financial adapter: 5 MB
Total: 465 MB (65% less!)
```
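
The savings grow linearly with the number of domains. A quick back-of-the-envelope check, using the illustrative sizes from the comparison above:

```python
def storage_mb(num_domains, base_mb=450, adapter_mb=5):
    """Compare total storage: one full model per domain vs. one base + adapters."""
    full_finetuning = num_domains * base_mb
    lora = base_mb + num_domains * adapter_mb
    return full_finetuning, lora

full, lora = storage_mb(3)
print(full, lora)                       # -> 1350 465
print(f"{1 - lora / full:.1%} less")    # -> 65.6% less with three domains

full10, lora10 = storage_mb(10)
print(full10, lora10)                   # -> 4500 500: the gap widens as domains grow
```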

### Fast Training

LoRA adapters train **2-3x faster** than full fine-tuning because:
- Only ~1-5% of parameters are trainable
- Smaller gradient computations
- Less GPU memory required
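
The trainable fraction is easy to estimate: adapting a d×k weight with rank r trains r·(d+k) parameters instead of d·k. A hedged sketch with generic transformer dimensions (768 is assumed for illustration, not GLiNER2's exact shapes):

```python
def lora_params(d, k, r):
    """Parameters in a rank-r LoRA update of a d x k weight: A is r x k, B is d x r."""
    return r * (d + k)

d = k = 768  # hidden size of a typical base encoder (illustrative)
r = 8
full = d * k                 # 589,824 params in the dense weight
lora = lora_params(d, k, r)  # 12,288 params in the adapter
print(f"{lora:,} vs {full:,} ({lora / full:.1%} of the layer)")
# -> 12,288 vs 589,824 (2.1% of the layer)
```

Summed over every adapted layer, this is where figures like the ~1% trainable-parameter count in the training output below come from.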

### Easy Multi-Domain Inference

```python
# One base model, multiple domains
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")

# Legal domain
model.load_adapter("./legal_adapter")
legal_results = model.extract_entities(legal_text, ["company", "law"])

# Medical domain (swap in <1 second)
model.load_adapter("./medical_adapter")
medical_results = model.extract_entities(medical_text, ["disease", "drug"])
```

## Training Your First Adapter

### Step 1: Prepare Domain-Specific Data

```python
from gliner2.training.data import InputExample

# Legal domain examples
legal_examples = [
    InputExample(
        text="Apple Inc. filed a lawsuit against Samsung Electronics.",
        entities={"company": ["Apple Inc.", "Samsung Electronics"]}
    ),
    InputExample(
        text="The plaintiff Google LLC accused Microsoft Corporation of patent infringement.",
        entities={"company": ["Google LLC", "Microsoft Corporation"]}
    ),
    InputExample(
        text="Tesla Motors settled the case with the Securities and Exchange Commission.",
        entities={
            "company": ["Tesla Motors"],
            "organization": ["Securities and Exchange Commission"]
        }
    ),
    # Add 100-1000+ examples for best results
]
```

### Step 2: Configure LoRA Training

```python
from gliner2 import GLiNER2
from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig

# LoRA configuration
config = TrainingConfig(
    output_dir="./legal_adapter",
    experiment_name="legal_domain",

    # Training parameters
    num_epochs=10,
    batch_size=8,
    gradient_accumulation_steps=2,
    encoder_lr=1e-5,
    task_lr=5e-4,

    # LoRA settings
    use_lora=True,                    # Enable LoRA
    lora_r=8,                         # Rank (4, 8, 16, 32)
    lora_alpha=16.0,                  # Scaling factor (usually 2*r)
    lora_dropout=0.0,                 # Dropout for LoRA layers
    lora_target_modules=["encoder"],  # Apply to all encoder layers (query, key, value, dense)
    save_adapter_only=True,           # Save only the adapter (not the full model)

    # Optimization
    eval_strategy="epoch",            # Evaluates and saves at the end of each epoch
    eval_steps=500,                   # Used when eval_strategy="steps"
    logging_steps=50,
    fp16=True,                        # Use mixed precision if a GPU is available
)
```

### Step 3: Train the Adapter

```python
# Load base model
base_model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")

# Create trainer
trainer = GLiNER2Trainer(model=base_model, config=config)

# Train adapter
trainer.train(train_data=legal_examples)

# Adapter automatically saved to ./legal_adapter/final/
```

**Training output:**
```
🔧 LoRA Configuration
======================================================================
Enabled          : True
Rank (r)         : 8
Alpha            : 16.0
Scaling (α/r)    : 2.0000
Dropout          : 0.0
Target modules   : query, key, value, dense
LoRA layers      : 144
----------------------------------------------------------------------
Trainable params : 1,327,104 / 124,442,368 (1.07%)
Memory savings   : ~98.9% fewer gradients
======================================================================

***** Running Training *****
Num examples = 1000
Num epochs = 10
Batch size = 8
Effective batch size = 16
Total optimization steps = 625
LoRA enabled: 1,327,104 trainable / 124,442,368 total (1.07%)
```

## Training Multiple Domain Adapters

Let's train adapters for three different domains: **Legal**, **Medical**, and **Customer Support**.

### Complete Multi-Domain Training Script

```python
from gliner2 import GLiNER2
from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig
from gliner2.training.data import InputExample

# ============================================================================
# Define Domain Data
# ============================================================================

# Legal domain
legal_examples = [
    InputExample(
        text="Apple Inc. filed a lawsuit against Samsung Electronics.",
        entities={"company": ["Apple Inc.", "Samsung Electronics"]}
    ),
    InputExample(
        text="The plaintiff Google LLC accused Microsoft Corporation of patent infringement.",
        entities={"company": ["Google LLC", "Microsoft Corporation"]}
    ),
    # Add more examples...
]

# Medical domain
medical_examples = [
    InputExample(
        text="Patient diagnosed with Type 2 Diabetes and Hypertension.",
        entities={"disease": ["Type 2 Diabetes", "Hypertension"]}
    ),
    InputExample(
        text="Prescribed Metformin 500mg twice daily and Lisinopril 10mg once daily.",
        entities={
            "drug": ["Metformin", "Lisinopril"],
            "dosage": ["500mg", "10mg"]
        }
    ),
    # Add more examples...
]

# Customer support domain
support_examples = [
    InputExample(
        text="Customer John Smith reported issue with Order #12345.",
        entities={
            "customer": ["John Smith"],
            "order_id": ["Order #12345"]
        }
    ),
    InputExample(
        text="Refund of $99.99 processed for Order #98765 on 2024-01-15.",
        entities={
            "order_id": ["Order #98765"],
            "amount": ["$99.99"],
            "date": ["2024-01-15"]
        }
    ),
    # Add more examples...
]

# ============================================================================
# Training Function
# ============================================================================

def train_domain_adapter(
    base_model_name: str,
    examples: list,
    domain_name: str,
    output_dir: str = "./adapters"
):
    """Train a LoRA adapter for a specific domain."""

    adapter_path = f"{output_dir}/{domain_name}_adapter"

    config = TrainingConfig(
        output_dir=adapter_path,
        experiment_name=f"{domain_name}_domain",

        # Training
        num_epochs=10,
        batch_size=8,
        gradient_accumulation_steps=2,
        encoder_lr=1e-5,
        task_lr=5e-4,

        # LoRA
        use_lora=True,
        lora_r=8,
        lora_alpha=16.0,
        lora_dropout=0.0,
        lora_target_modules=["encoder"],  # All encoder layers
        save_adapter_only=True,

        # Logging & Checkpointing
        eval_strategy="no",  # Set to "epoch" or "steps" if you have a validation set
        eval_steps=500,      # Used when eval_strategy="steps"
        logging_steps=50,
        fp16=True,
    )

    # Load base model
    print(f"\n{'='*60}")
    print(f"Training {domain_name.upper()} adapter")
    print(f"{'='*60}")

    model = GLiNER2.from_pretrained(base_model_name)
    trainer = GLiNER2Trainer(model=model, config=config)

    # Train
    results = trainer.train(train_data=examples)

    print(f"\n✅ {domain_name.capitalize()} adapter trained!")
    print(f"📁 Saved to: {adapter_path}/final/")
    print(f"⏱️ Training time: {results['total_time_seconds']:.2f}s")

    return f"{adapter_path}/final"

# ============================================================================
# Train All Adapters
# ============================================================================

if __name__ == "__main__":
    BASE_MODEL = "fastino/gliner2-base-v1"

    # Train adapters for each domain
    legal_adapter_path = train_domain_adapter(
        BASE_MODEL, legal_examples, "legal"
    )

    medical_adapter_path = train_domain_adapter(
        BASE_MODEL, medical_examples, "medical"
    )

    support_adapter_path = train_domain_adapter(
        BASE_MODEL, support_examples, "support"
    )

    print("\n" + "="*60)
    print("🎉 All adapters trained successfully!")
    print("="*60)
    print(f"Legal adapter:   {legal_adapter_path}")
    print(f"Medical adapter: {medical_adapter_path}")
    print(f"Support adapter: {support_adapter_path}")
```

## Loading and Swapping Adapters

### Basic Usage

```python
from gliner2 import GLiNER2

# Load base model once
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")

# Load legal adapter
model.load_adapter("./adapters/legal_adapter/final")

# Use the model
result = model.extract_entities(
    "Apple Inc. sued Samsung over patent rights.",
    ["company", "legal_action"]
)
print(result)
```

### Swapping Between Adapters

```python
# Load base model
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")

# Legal domain
print("📋 Legal Analysis:")
model.load_adapter("./adapters/legal_adapter/final")
legal_text = "Google LLC filed a complaint against Oracle Corporation."
legal_result = model.extract_entities(legal_text, ["company", "legal_action"])
print(f"   {legal_result}")

# Swap to medical domain
print("\n🏥 Medical Analysis:")
model.load_adapter("./adapters/medical_adapter/final")
medical_text = "Patient presents with Pneumonia and was prescribed Amoxicillin."
medical_result = model.extract_entities(medical_text, ["disease", "drug"])
print(f"   {medical_result}")

# Swap to support domain
print("\n💬 Support Analysis:")
model.load_adapter("./adapters/support_adapter/final")
support_text = "Customer reported Order #12345 not delivered on time."
support_result = model.extract_entities(support_text, ["order_id", "issue"])
print(f"   {support_result}")

# Use base model without adapter
print("\n🔧 Base Model (no adapter):")
model.unload_adapter()
base_result = model.extract_entities("Some generic text", ["entity"])
print(f"   {base_result}")
```

**Output:**
```
📋 Legal Analysis:
   {'entities': [{'text': 'Google LLC', 'label': 'company', ...},
                 {'text': 'Oracle Corporation', 'label': 'company', ...}]}

🏥 Medical Analysis:
   {'entities': [{'text': 'Pneumonia', 'label': 'disease', ...},
                 {'text': 'Amoxicillin', 'label': 'drug', ...}]}

💬 Support Analysis:
   {'entities': [{'text': 'Order #12345', 'label': 'order_id', ...}]}

🔧 Base Model (no adapter):
   {'entities': [{'text': 'text', 'label': 'entity', ...}]}
```

### Batch Processing with Adapter Swapping

```python
def process_documents_by_domain(model, documents_by_domain, adapters):
    """
    Process multiple documents across different domains efficiently.

    Args:
        model: Base GLiNER2 model
        documents_by_domain: Dict[domain_name, List[document_text]]
        adapters: Dict[domain_name, adapter_path]

    Returns:
        Dict[domain_name, List[results]]
    """
    results = {}

    for domain, documents in documents_by_domain.items():
        print(f"Processing {domain} domain ({len(documents)} documents)...")

        # Load domain-specific adapter
        model.load_adapter(adapters[domain])

        # Process all documents for this domain
        domain_results = []
        for doc in documents:
            result = model.extract_entities(doc, get_entity_types(domain))
            domain_results.append(result)

        results[domain] = domain_results

    return results

def get_entity_types(domain):
    """Get entity types for each domain."""
    types = {
        "legal": ["company", "person", "law", "legal_action"],
        "medical": ["disease", "drug", "symptom", "procedure"],
        "support": ["customer", "order_id", "product", "issue"]
    }
    return types.get(domain, ["entity"])

# Example usage
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")

documents_by_domain = {
    "legal": [
        "Apple Inc. filed suit against Samsung.",
        "Microsoft acquired LinkedIn for $26B.",
    ],
    "medical": [
        "Patient has Type 2 Diabetes.",
        "Prescribed Metformin 500mg daily.",
    ],
    "support": [
        "Issue with Order #12345 reported.",
        "Refund processed for Order #98765.",
    ]
}

adapters = {
    "legal": "./adapters/legal_adapter/final",
    "medical": "./adapters/medical_adapter/final",
    "support": "./adapters/support_adapter/final",
}

results = process_documents_by_domain(model, documents_by_domain, adapters)

# Results organized by domain
for domain, domain_results in results.items():
    print(f"\n{domain.upper()} Results:")
    for i, result in enumerate(domain_results, 1):
        print(f"   Document {i}: {len(result['entities'])} entities found")
```

## Real-World Use Cases

### Use Case 1: Multi-Tenant SaaS Platform

```python
class MultiTenantEntityExtractor:
    """Entity extraction service for a multi-tenant platform."""

    def __init__(self, base_model_name: str, tenant_adapters: dict):
        """
        Args:
            base_model_name: Path to base model
            tenant_adapters: Dict mapping tenant_id to adapter_path
        """
        self.model = GLiNER2.from_pretrained(base_model_name)
        self.tenant_adapters = tenant_adapters
        self.current_tenant = None

    def extract_for_tenant(self, tenant_id: str, text: str, entity_types: list):
        """Extract entities for a specific tenant."""
        # Load the tenant-specific adapter only when the tenant changes
        if self.current_tenant != tenant_id:
            adapter_path = self.tenant_adapters.get(tenant_id)
            if adapter_path:
                self.model.load_adapter(adapter_path)
            else:
                self.model.unload_adapter()  # Use base model
            self.current_tenant = tenant_id

        return self.model.extract_entities(text, entity_types)

# Setup
extractor = MultiTenantEntityExtractor(
    base_model_name="fastino/gliner2-base-v1",
    tenant_adapters={
        "legal_firm_123": "./adapters/legal_adapter/final",
        "hospital_456": "./adapters/medical_adapter/final",
        "ecommerce_789": "./adapters/support_adapter/final",
    }
)

# Usage
legal_result = extractor.extract_for_tenant(
    "legal_firm_123",
    "Apple sued Samsung",
    ["company"]
)

medical_result = extractor.extract_for_tenant(
    "hospital_456",
    "Patient has diabetes",
    ["disease"]
)
```
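
Since each adapter swap has a cost, interleaved requests from many tenants cause churn. One simple mitigation, sketched below in plain Python (the batching logic only, with the model deliberately left out), is to sort a queued batch of requests by tenant so each adapter is loaded once per batch:

```python
from operator import itemgetter

def order_requests_by_tenant(requests):
    """Cluster a mixed request queue so each tenant's adapter loads only once.

    requests: list of (tenant_id, text) tuples in arrival order.
    Returns (ordered_requests, number_of_adapter_loads).
    """
    # A stable sort keeps per-tenant arrival order while clustering tenants
    ordered = sorted(requests, key=itemgetter(0))
    loads = len({tenant for tenant, _ in ordered})
    return ordered, loads

queue = [("legal_firm_123", "Apple sued Samsung"),
         ("hospital_456", "Patient has diabetes"),
         ("legal_firm_123", "Oracle filed a complaint"),
         ("hospital_456", "Prescribed Metformin")]

ordered, loads = order_requests_by_tenant(queue)
print(loads)  # -> 2 adapter loads instead of 4 for the naive interleaved order
```

This trades a little latency for the first requests in a batch against far fewer swaps overall; whether that trade is worth it depends on your latency budget.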

### Use Case 2: Document Classification Pipeline

```python
def classify_and_extract(document: str, model: GLiNER2, adapters: dict):
    """
    Classify document type and extract relevant entities.

    1. Classify document type using base model
    2. Load appropriate domain adapter
    3. Extract domain-specific entities
    """
    # Step 1: Classify document type
    doc_type_result = model.extract_entities(
        document,
        ["legal_document", "medical_record", "support_ticket", "financial_report"]
    )

    # Determine document type
    if doc_type_result['entities']:
        doc_type = doc_type_result['entities'][0]['label']
        doc_type = doc_type.replace("_document", "").replace("_record", "").replace("_ticket", "").replace("_report", "")
    else:
        doc_type = "general"

    # Step 2: Load appropriate adapter
    adapter_mapping = {
        "legal": adapters.get("legal"),
        "medical": adapters.get("medical"),
        "support": adapters.get("support"),
        "financial": adapters.get("financial"),
    }

    if doc_type in adapter_mapping and adapter_mapping[doc_type]:
        model.load_adapter(adapter_mapping[doc_type])

    # Step 3: Extract domain-specific entities
    entity_types = {
        "legal": ["company", "person", "law", "legal_action"],
        "medical": ["disease", "drug", "symptom", "procedure", "dosage"],
        "support": ["customer", "order_id", "product", "issue", "status"],
        "financial": ["company", "amount", "date", "stock_symbol"],
    }

    entities = model.extract_entities(
        document,
        entity_types.get(doc_type, ["entity"])
    )

    return {
        "document_type": doc_type,
        "entities": entities['entities']
    }

# Usage
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")

adapters = {
    "legal": "./adapters/legal_adapter/final",
    "medical": "./adapters/medical_adapter/final",
    "support": "./adapters/support_adapter/final",
}

document = "Patient John Smith diagnosed with Type 2 Diabetes on 2024-01-15."
result = classify_and_extract(document, model, adapters)

print(f"Document Type: {result['document_type']}")
print(f"Entities: {result['entities']}")
```

### Use Case 3: A/B Testing Adapters

```python
class AdapterABTester:
    """A/B test different adapter versions."""

    def __init__(self, base_model_name: str, adapter_variants: dict):
        """
        Args:
            adapter_variants: {"v1": path1, "v2": path2, ...}
        """
        self.model = GLiNER2.from_pretrained(base_model_name)
        self.adapter_variants = adapter_variants
        self.results = {variant: [] for variant in adapter_variants}

    def test_sample(self, text: str, entity_types: list, true_entities: list):
        """Test a sample with all adapter variants."""
        sample_results = {}

        for variant, adapter_path in self.adapter_variants.items():
            # Load variant
            self.model.load_adapter(adapter_path)

            # Get predictions
            pred = self.model.extract_entities(text, entity_types)

            # Compute metrics
            f1 = self.compute_f1(pred['entities'], true_entities)

            sample_results[variant] = {
                "predictions": pred['entities'],
                "f1_score": f1
            }

            self.results[variant].append(f1)

        return sample_results

    def compute_f1(self, predicted, ground_truth):
        """Simple F1 computation (simplified for demo)."""
        pred_set = {(e['text'], e['label']) for e in predicted}
        true_set = {(e['text'], e['label']) for e in ground_truth}

        if not pred_set and not true_set:
            return 1.0
        if not pred_set or not true_set:
            return 0.0

        tp = len(pred_set & true_set)
        precision = tp / len(pred_set)
        recall = tp / len(true_set)

        if precision + recall == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    def get_summary(self):
        """Get A/B test summary."""
        summary = {}
        for variant, scores in self.results.items():
            if scores:
                summary[variant] = {
                    "avg_f1": sum(scores) / len(scores),
                    "samples": len(scores)
                }
        return summary

# Usage
tester = AdapterABTester(
    base_model_name="fastino/gliner2-base-v1",
    adapter_variants={
        "v1_r4": "./adapters/legal_v1_r4/final",
        "v2_r8": "./adapters/legal_v2_r8/final",
        "v3_r16": "./adapters/legal_v3_r16/final",
    }
)

# Test samples
test_samples = [
    {
        "text": "Apple Inc. sued Samsung Electronics.",
        "entity_types": ["company"],
        "true_entities": [
            {"text": "Apple Inc.", "label": "company"},
            {"text": "Samsung Electronics", "label": "company"}
        ]
    },
    # More samples...
]

for sample in test_samples:
    results = tester.test_sample(
        sample["text"],
        sample["entity_types"],
        sample["true_entities"]
    )

# Get summary
summary = tester.get_summary()
for variant, metrics in summary.items():
    print(f"{variant}: Avg F1 = {metrics['avg_f1']:.3f} ({metrics['samples']} samples)")
```
|
||||
|
||||
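Average F1 alone can hide noise: with few samples, one variant may win by chance. A paired bootstrap over the per-sample scores collected in `self.results` gives a rough sense of whether one variant reliably beats another. A minimal sketch (not part of the GLiNER2 API):

```python
import random

def paired_bootstrap(scores_a, scores_b, n_resamples=2000, seed=0):
    """Fraction of bootstrap resamples in which variant A's mean
    per-sample F1 beats variant B's. Paired: the same sample indices
    are drawn for both variants in each resample."""
    rng = random.Random(seed)
    n = len(scores_a)
    wins = 0
    for _ in range(n_resamples):
        idx = [rng.randrange(n) for _ in range(n)]
        mean_a = sum(scores_a[i] for i in idx) / n
        mean_b = sum(scores_b[i] for i in idx) / n
        if mean_a > mean_b:
            wins += 1
    return wins / n_resamples

# e.g. p = paired_bootstrap(tester.results["v2_r8"], tester.results["v1_r4"])
```

Values near 1.0 (or 0.0) suggest a consistent winner; values near 0.5 suggest the difference is within noise.
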
## Best Practices

### 1. Choosing LoRA Hyperparameters

```python
# Small datasets (< 1K examples)
config = TrainingConfig(
    lora_r=4,         # Lower rank = fewer parameters
    lora_alpha=8.0,   # alpha = 2 * r
    num_epochs=10,
)

# Medium datasets (1K-10K examples)
config = TrainingConfig(
    lora_r=8,         # Standard rank
    lora_alpha=16.0,
    num_epochs=5,
)

# Large datasets (> 10K examples)
config = TrainingConfig(
    lora_r=16,        # Higher rank = more capacity
    lora_alpha=32.0,
    num_epochs=3,
)
```

### 2. Target Module Selection

**Understanding Module Groups:**

GLiNER2 supports fine-grained control over which layers receive LoRA adaptation:

```python
# Option 1: Encoder only - all layers (query, key, value, dense)
# Use case: General domain adaptation, good starting point
# Memory: Moderate (~1-2% of model parameters)
lora_target_modules=["encoder"]

# Option 2: Encoder - attention layers only
# Use case: Very memory-constrained scenarios
# Memory: Low (~0.5-1% of model parameters)
lora_target_modules=["encoder.query", "encoder.key", "encoder.value"]

# Option 3: Encoder - FFN layers only
# Use case: Alternative to attention-only, sometimes better for certain tasks
# Memory: Low (~0.5-1% of model parameters)
lora_target_modules=["encoder.dense"]

# Option 4: Encoder + task heads
# Use case: When you want to adapt both representation and task-specific layers
# Memory: Moderate-High (~2-4% of model parameters)
lora_target_modules=["encoder", "span_rep", "classifier"]

# Option 5: All modules (DEFAULT)
# Use case: Maximum adaptation capacity, best performance
# Memory: High (~3-5% of model parameters)
lora_target_modules=["encoder", "span_rep", "classifier", "count_embed", "count_pred"]
```

**Recommendations:**
- **Start with encoder only** (`["encoder"]`) for most tasks
- **Add task heads** if performance is insufficient
- **Use attention-only** for extreme memory constraints
- **Use all modules** (default) when you need maximum performance

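The memory percentages above can be sanity-checked by counting the extra parameters LoRA introduces: each adapted weight of shape (d_out, d_in) gains two small matrices, A of shape (r, d_in) and B of shape (d_out, r). A rough sketch (the layer shapes below are illustrative, not GLiNER2's actual dimensions):

```python
def lora_param_count(shapes, r):
    """Extra trainable parameters added by LoRA for the given
    adapted weight shapes: r * (d_in + d_out) per weight matrix."""
    return sum(r * (d_in + d_out) for (d_out, d_in) in shapes)

# Toy example: 12 attention projections of shape (768, 768) at rank 8
shapes = [(768, 768)] * 12
print(lora_param_count(shapes, r=8))  # 147456 extra parameters
```

Doubling the rank doubles the adapter size, which is why attention-only targeting at low rank is the cheapest option.
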
### 3. Adapter Organization

```
project/
├── base_model/
│   └── gliner2-base-v1/
├── adapters/
│   ├── legal/
│   │   ├── v1_r8/
│   │   │   └── final/
│   │   └── v2_r16/
│   │       └── final/
│   ├── medical/
│   │   └── final/
│   └── support/
│       └── final/
└── scripts/
    ├── train_adapters.py
    └── evaluate_adapters.py
```

### 4. Version Control for Adapters

```python
# adapter_metadata.json
{
    "legal_v1": {
        "path": "./adapters/legal/v1_r8/final",
        "base_model": "fastino/gliner2-base-v1",
        "lora_r": 8,
        "lora_alpha": 16.0,
        "trained_on": "2024-01-15",
        "training_samples": 5000,
        "eval_f1": 0.87,
        "notes": "Initial legal domain adapter"
    },
    "legal_v2": {
        "path": "./adapters/legal/v2_r16/final",
        "base_model": "fastino/gliner2-base-v1",
        "lora_r": 16,
        "lora_alpha": 32.0,
        "trained_on": "2024-02-01",
        "training_samples": 10000,
        "eval_f1": 0.92,
        "notes": "Improved with more data and higher rank"
    }
}
```

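With a metadata file like this in place, choosing which adapter to deploy can be automated. A small sketch (this helper is not part of GLiNER2):

```python
import json

def best_adapter(metadata, metric="eval_f1"):
    """Pick the adapter entry with the highest recorded metric.
    `metadata` is a dict shaped like adapter_metadata.json,
    or a path to that file."""
    if isinstance(metadata, str):
        with open(metadata) as f:
            metadata = json.load(f)
    name = max(metadata, key=lambda k: metadata[k][metric])
    return name, metadata[name]["path"]

# name, path = best_adapter("adapter_metadata.json")
# model.load_adapter(path)
```
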
### 5. Monitoring Adapter Performance

```python
def evaluate_adapter(model, adapter_path, test_data):
    """Evaluate adapter performance on test data."""
    model.load_adapter(adapter_path)

    results = {
        "total": 0,
        "precision_sum": 0.0,
        "recall_sum": 0.0,
    }

    for sample in test_data:
        pred = model.extract_entities(sample["text"], sample["entity_types"])

        # Compute metrics
        metrics = compute_metrics(pred['entities'], sample["true_entities"])
        results["total"] += 1
        results["precision_sum"] += metrics["precision"]
        results["recall_sum"] += metrics["recall"]

    avg_precision = results["precision_sum"] / results["total"]
    avg_recall = results["recall_sum"] / results["total"]
    denom = avg_precision + avg_recall
    f1 = 2 * avg_precision * avg_recall / denom if denom else 0.0

    return {
        "precision": avg_precision,
        "recall": avg_recall,
        "f1": f1,
        "samples": results["total"]
    }
```

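The function above assumes a `compute_metrics` helper that this tutorial does not define. A minimal sketch that scores exact (text, label) matches:

```python
def compute_metrics(predicted, ground_truth):
    """Precision/recall over exact (text, label) matches."""
    pred_set = {(e['text'], e['label']) for e in predicted}
    true_set = {(e['text'], e['label']) for e in ground_truth}
    tp = len(pred_set & true_set)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(true_set) if true_set else 0.0
    return {"precision": precision, "recall": recall}
```

Exact-match scoring is strict; partial-overlap scoring is common too, but needs span offsets.
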
## Troubleshooting

### Issue 1: Adapter Not Affecting Predictions

**Symptom**: Predictions are the same with and without the adapter.

**Solution**:
```python
# Check if the adapter is actually loaded
print(f"Has adapter: {model.has_adapter}")

# Check LoRA layers
from gliner2.training.lora import LoRALayer
lora_count = sum(1 for m in model.modules() if isinstance(m, LoRALayer))
print(f"LoRA layers: {lora_count}")

# Should be > 0 if an adapter is loaded
assert lora_count > 0, "No LoRA layers found!"
```

### Issue 2: Out of Memory During Training

**Solution**:
```python
config = TrainingConfig(
    # Reduce batch size
    batch_size=4,                   # Instead of 8
    gradient_accumulation_steps=4,  # Maintain effective batch size

    # Use smaller LoRA rank
    lora_r=4,                       # Instead of 8

    # Enable mixed precision
    fp16=True,

    # Target only attention layers (fewer parameters)
    lora_target_modules=["encoder.query", "encoder.key", "encoder.value"],
)
```

### Issue 3: Adapter File Not Found

**Solution**:
```python
import os
from gliner2.training.lora import LoRAAdapterConfig

adapter_path = "./adapters/legal_adapter/final"

# Check if the path exists
if not os.path.exists(adapter_path):
    print(f"Path does not exist: {adapter_path}")
    # List available checkpoints
    checkpoint_dir = "./adapters/legal_adapter"
    if os.path.exists(checkpoint_dir):
        checkpoints = os.listdir(checkpoint_dir)
        print(f"Available checkpoints: {checkpoints}")

# Check if it's a valid adapter
if LoRAAdapterConfig.is_adapter_path(adapter_path):
    print("Valid adapter path!")
    config = LoRAAdapterConfig.load(adapter_path)
    print(f"Adapter config: {config}")
else:
    print("Not a valid adapter path!")
```

### Issue 4: Slow Adapter Switching

**Problem**: Switching between adapters takes too long.

**Solution**:
```python
# Pre-load adapters in memory (if you have enough RAM)
adapters = {}
for domain, path in adapter_paths.items():
    # load_adapter_to_memory is a hypothetical caching helper;
    # fast in-memory switching is not implemented in the base API,
    # but is possible with a custom caching layer.
    adapters[domain] = load_adapter_to_memory(path)
```

## Summary

### Key Takeaways

✅ **LoRA adapters** enable efficient multi-domain inference
✅ **Training** is 2-3x faster than full fine-tuning
✅ **Storage** savings of 65-95% compared to multiple full models
✅ **Swapping** adapters takes < 1 second
✅ **Domain specialization** improves accuracy on specific tasks

### Quick Reference

```python
# Training
config = TrainingConfig(
    use_lora=True,
    lora_r=8,
    lora_alpha=16.0,
    save_adapter_only=True,
)
trainer.train(train_data=examples)

# Loading
model = GLiNER2.from_pretrained("base-model")
model.load_adapter("./adapter/final")

# Swapping
model.load_adapter("./other_adapter/final")

# Unloading
model.unload_adapter()

# Checking
print(model.has_adapter)
print(model.adapter_config)
```

### Next Steps

1. **Train your first adapter** with domain-specific data
2. **Evaluate performance** on a test set
3. **Experiment with hyperparameters** (rank, alpha, target modules)
4. **Deploy multiple adapters** for different use cases
5. **Monitor and iterate** based on real-world performance

For more information:
- LoRA paper: https://arxiv.org/abs/2106.09685
- Implementation: `gliner2/training/lora.py`
- Tests: `tests/test_lora_adapters.py`
- Verification guide: `LORA_VERIFICATION_TESTS.md`

packages/GLiNER2/tutorial/11-adapter_switching.md (new file, 201 lines)

# Tutorial 11: LoRA Adapter Switching/Routing

## Quick Start

Switch between domain-specific adapters during inference without reloading the base model.

```python
from gliner2 import GLiNER2

# Load base model once
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")

# Load legal adapter
model.load_adapter("./legal_adapter")
legal_result = model.extract_entities("Apple sued Google", ["company"])

# Switch to medical adapter
model.load_adapter("./medical_adapter")
medical_result = model.extract_entities("Patient has diabetes", ["disease"])

# Use base model (no adapter)
model.unload_adapter()
base_result = model.extract_entities("Some text", ["entity"])
```

## Basic Usage

### Loading an Adapter

```python
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
model.load_adapter("./path/to/adapter")
```

The adapter path should point to a directory containing:
- `adapter_config.json`
- `adapter_weights.safetensors`

### Checking Adapter Status

```python
# Check if an adapter is loaded
if model.has_adapter:
    print("Adapter is loaded")

# Get adapter configuration
config = model.adapter_config
print(f"LoRA rank: {config.lora_r}")
```

### Unloading an Adapter

```python
# Remove adapter, use base model
model.unload_adapter()
```

## Switching Between Adapters

Adapters automatically swap when you call `load_adapter()`:

```python
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")

# Legal domain
model.load_adapter("./legal_adapter")
result1 = model.extract_entities("Apple Inc. filed suit", ["company"])

# Medical domain (previous adapter auto-unloaded)
model.load_adapter("./medical_adapter")
result2 = model.extract_entities("Patient has diabetes", ["disease"])

# Support domain
model.load_adapter("./support_adapter")
result3 = model.extract_entities("Order #12345 issue", ["order_id"])
```

## Routing by Document Type

Route documents to appropriate adapters:

```python
def extract_with_routing(model, text, doc_type, adapters):
    """Route document to domain-specific adapter."""
    adapter_path = adapters.get(doc_type)

    if adapter_path:
        model.load_adapter(adapter_path)
    else:
        model.unload_adapter()  # Use base model

    # Define entity types per domain
    entity_types = {
        "legal": ["company", "person", "law"],
        "medical": ["disease", "drug", "symptom"],
        "support": ["order_id", "customer", "issue"]
    }

    return model.extract_entities(
        text,
        entity_types.get(doc_type, ["entity"])
    )


# Setup
model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
adapters = {
    "legal": "./legal_adapter",
    "medical": "./medical_adapter",
    "support": "./support_adapter"
}

# Use
result = extract_with_routing(
    model,
    "Apple sued Google",
    "legal",
    adapters
)
```

## Batch Processing by Domain

Process multiple documents efficiently by grouping them per domain, so each adapter is loaded only once:

```python
def process_by_domain(model, documents, adapters):
    """Process documents grouped by domain."""
    results = {}

    for domain, docs in documents.items():
        # Load domain adapter once per group
        model.load_adapter(adapters[domain])

        # Process all documents for this domain
        # (get_entity_types maps a domain to its entity types; defined elsewhere)
        results[domain] = [
            model.extract_entities(doc, get_entity_types(domain))
            for doc in docs
        ]

    return results


# Example
documents = {
    "legal": ["Apple sued Samsung", "Microsoft acquired LinkedIn"],
    "medical": ["Patient has diabetes", "Prescribed Metformin"]
}

adapters = {
    "legal": "./legal_adapter",
    "medical": "./medical_adapter"
}

results = process_by_domain(model, documents, adapters)
```

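If documents arrive as a flat stream of (text, domain) pairs rather than pre-grouped, grouping them first keeps adapter loads to one per domain. A small helper sketch:

```python
from collections import defaultdict

def group_by_domain(stream):
    """Group (text, domain) pairs so each adapter loads only once."""
    grouped = defaultdict(list)
    for text, domain in stream:
        grouped[domain].append(text)
    return dict(grouped)

stream = [("Apple sued Samsung", "legal"),
          ("Patient has diabetes", "medical"),
          ("Microsoft acquired LinkedIn", "legal")]
print(group_by_domain(stream))
# {'legal': ['Apple sued Samsung', 'Microsoft acquired LinkedIn'],
#  'medical': ['Patient has diabetes']}
```
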
## Simple Router Class

```python
class AdapterRouter:
    """Simple adapter router for multi-domain inference."""

    def __init__(self, base_model_name, adapters):
        self.model = GLiNER2.from_pretrained(base_model_name)
        self.adapters = adapters
        self.current_domain = None

    def extract(self, text, domain, entity_types):
        """Extract entities using domain-specific adapter."""
        # Load adapter only if the domain changed
        if self.current_domain != domain:
            adapter_path = self.adapters.get(domain)
            if adapter_path:
                self.model.load_adapter(adapter_path)
            else:
                self.model.unload_adapter()
            self.current_domain = domain

        return self.model.extract_entities(text, entity_types)


# Usage
router = AdapterRouter(
    "fastino/gliner2-base-v1",
    {
        "legal": "./legal_adapter",
        "medical": "./medical_adapter"
    }
)

result = router.extract("Apple sued Google", "legal", ["company"])
```

## Summary

- **Load adapter**: `model.load_adapter(path)`
- **Unload adapter**: `model.unload_adapter()`
- **Check status**: `model.has_adapter`
- **Get config**: `model.adapter_config`
- **Auto-swap**: Loading a new adapter automatically unloads the previous one

For training adapters, see [Tutorial 10: LoRA Adapters](10-lora_adapters.md).

packages/GLiNER2/tutorial/2-ner.md (new file, 372 lines)

# GLiNER2 Entity Extraction Tutorial

Learn how to extract named entities from text using GLiNER2's flexible entity recognition capabilities.

## Table of Contents
- [Basic Entity Extraction](#basic-entity-extraction)
- [Entity Extraction with Descriptions](#entity-extraction-with-descriptions)
- [Single vs Multiple Entities](#single-vs-multiple-entities)
- [Custom Thresholds](#custom-thresholds)
- [Advanced Configuration](#advanced-configuration)
- [Domain-Specific Entities](#domain-specific-entities)
- [Best Practices](#best-practices)

## Basic Entity Extraction

### Simple Example

```python
from gliner2 import GLiNER2

# Load model
extractor = GLiNER2.from_pretrained("your-model-name")

# Extract common entities
text = "Apple Inc. CEO Tim Cook announced the new iPhone 15 in Cupertino, California on September 12, 2023."
results = extractor.extract_entities(
    text,
    ["company", "person", "product", "location", "date"]
)
print(results)
# Output: {
#     'entities': {
#         'company': ['Apple Inc.'],
#         'person': ['Tim Cook'],
#         'product': ['iPhone 15'],
#         'location': ['Cupertino', 'California'],
#         'date': ['September 12, 2023']
#     }
# }
```

### Using Schema Builder

```python
# Same extraction using a schema
schema = extractor.create_schema().entities([
    "company", "person", "product", "location", "date"
])
results = extractor.extract(text, schema)
```

## Entity Extraction with Descriptions

Descriptions significantly improve extraction accuracy by providing context.

```python
# Medical entity extraction
schema = extractor.create_schema().entities({
    "drug": "Pharmaceutical drugs, medications, or treatment names",
    "disease": "Medical conditions, illnesses, or disorders",
    "symptom": "Clinical symptoms or patient-reported symptoms",
    "dosage": "Medication amounts like '50mg' or '2 tablets daily'",
    "organ": "Body parts or organs mentioned in medical context"
})

medical_text = """
Patient was prescribed Metformin 500mg twice daily for Type 2 Diabetes.
She reported fatigue and occasional dizziness. Liver function tests ordered.
"""

results = extractor.extract(medical_text, schema)
print(results)
# Output: {
#     'entities': {
#         'drug': ['Metformin'],
#         'disease': ['Type 2 Diabetes'],
#         'symptom': ['fatigue', 'dizziness'],
#         'dosage': ['500mg twice daily'],
#         'organ': ['Liver']
#     }
# }
```

## Single vs Multiple Entities

Control whether to extract one or multiple entities per type.

### Multiple Entities (Default)

```python
# Default behavior - extracts all matching entities
schema = extractor.create_schema().entities(
    ["person", "organization"],
    dtype="list"  # Default
)

text = "Bill Gates and Steve Jobs founded Microsoft and Apple respectively."
results = extractor.extract(text, schema)
# Output: {
#     'entities': {
#         'person': ['Bill Gates', 'Steve Jobs'],
#         'organization': ['Microsoft', 'Apple']
#     }
# }
```

### Single Entity per Type

```python
# Extract only the best match per entity type
schema = extractor.create_schema().entities(
    ["company", "ceo"],
    dtype="str"  # Single entity mode
)

text = "Apple CEO Tim Cook met with Microsoft CEO Satya Nadella."
results = extractor.extract(text, schema)
# Output: {
#     'entities': {
#         'company': 'Apple',   # Just one, despite multiple in text
#         'ceo': 'Tim Cook'     # Just one
#     }
# }
```

## Custom Thresholds

Set confidence thresholds for precise control.

### Global Threshold

```python
# High-precision extraction
results = extractor.extract_entities(
    text,
    ["email", "phone", "address"],
    threshold=0.8  # High confidence required
)
```

### With Confidence Scores and Character Positions

You can include confidence scores and character-level start/end positions using the `include_confidence` and `include_spans` parameters:

```python
# Extract entities with confidence scores
text = "Apple Inc. CEO Tim Cook announced iPhone 15 in Cupertino."
results = extractor.extract_entities(
    text,
    ["company", "person", "product"],
    include_confidence=True
)
print(results)
# Output: {
#     'entities': {
#         'company': [
#             {'text': 'Apple Inc.', 'confidence': 0.95}
#         ],
#         'person': [
#             {'text': 'Tim Cook', 'confidence': 0.92}
#         ],
#         'product': [
#             {'text': 'iPhone 15', 'confidence': 0.88}
#         ]
#     }
# }

# Extract with character positions (spans)
results = extractor.extract_entities(
    text,
    ["company", "person"],
    include_spans=True
)
print(results)
# Output: {
#     'entities': {
#         'company': [
#             {'text': 'Apple Inc.', 'start': 0, 'end': 10}
#         ],
#         'person': [
#             {'text': 'Tim Cook', 'start': 15, 'end': 23}
#         ]
#     }
# }

# Extract with both confidence and spans
results = extractor.extract_entities(
    text,
    ["company", "product"],
    include_confidence=True,
    include_spans=True
)
print(results)
# Output: {
#     'entities': {
#         'company': [
#             {'text': 'Apple Inc.', 'confidence': 0.95, 'start': 0, 'end': 10}
#         ],
#         'product': [
#             {'text': 'iPhone 15', 'confidence': 0.88, 'start': 34, 'end': 43}
#         ]
#     }
# }
```

**Note**: The output format changes with these parameters:
- **Default** (both False): returns simple text strings: `['Apple Inc.', 'Tim Cook']`
- **include_confidence=True**: returns dicts with `{'text': '...', 'confidence': 0.95}`
- **include_spans=True**: returns dicts with `{'text': '...', 'start': 0, 'end': 10}`
- **Both True**: returns dicts with `{'text': '...', 'confidence': 0.95, 'start': 0, 'end': 10}`

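Downstream code that must accept any of these shapes can normalize them to one form first. A small sketch assuming the output shapes shown above:

```python
def normalize_entities(entities):
    """Coerce all output shapes to {label: [dict with 'text' key, ...]}."""
    normalized = {}
    for label, items in entities.items():
        if isinstance(items, str):  # dtype="str" single-entity mode
            items = [items]
        normalized[label] = [
            e if isinstance(e, dict) else {"text": e}
            for e in items
        ]
    return normalized

print(normalize_entities({
    "company": ["Apple Inc."],
    "person": [{"text": "Tim Cook", "confidence": 0.92}]
}))
# {'company': [{'text': 'Apple Inc.'}],
#  'person': [{'text': 'Tim Cook', 'confidence': 0.92}]}
```
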
### Per-Entity Thresholds

```python
# Different thresholds for different entities
schema = extractor.create_schema().entities({
    "email": {
        "description": "Email addresses",
        "dtype": "list",
        "threshold": 0.9  # Very high precision for emails
    },
    "phone": {
        "description": "Phone numbers including mobile and landline",
        "dtype": "list",
        "threshold": 0.7  # Moderate threshold
    },
    "name": {
        "description": "Person names",
        "dtype": "list",
        "threshold": 0.5  # Lower threshold for names
    }
})

contact_text = "Contact John Doe at john.doe@email.com or call 555-1234."
results = extractor.extract(contact_text, schema, threshold=0.6)  # Default threshold
```

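Per-entity thresholds are applied at extraction time. If you extract once with `include_confidence=True`, a similar effect can also be approximated after the fact with a post-filter. A sketch over that output shape:

```python
def filter_by_threshold(entities, thresholds, default=0.5):
    """Drop entities whose confidence is below the per-label threshold."""
    return {
        label: [e for e in items
                if e["confidence"] >= thresholds.get(label, default)]
        for label, items in entities.items()
    }

entities = {
    "email": [{"text": "john.doe@email.com", "confidence": 0.93}],
    "name": [{"text": "John Doe", "confidence": 0.55},
             {"text": "Doe", "confidence": 0.30}],
}
print(filter_by_threshold(entities, {"email": 0.9, "name": 0.5}))
# {'email': [{'text': 'john.doe@email.com', 'confidence': 0.93}],
#  'name': [{'text': 'John Doe', 'confidence': 0.55}]}
```

Post-filtering lets you tune thresholds without re-running the model, at the cost of extracting with a low global threshold first.
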
## Advanced Configuration

### Mixed Configuration

```python
# Combine different entity configurations
schema = extractor.create_schema()

# Add simple entities
schema.entities(["date", "time", "currency"])

# Add entities with descriptions
schema.entities({
    "technical_term": "Technical jargon or specialized terminology",
    "metric": "Measurements, KPIs, or quantitative values"
})

# Add entities with full configuration
schema.entities({
    "competitor": {
        "description": "Competing companies or products",
        "dtype": "list",
        "threshold": 0.7
    },
    "revenue": {
        "description": "Revenue figures or financial amounts",
        "dtype": "str",  # Only extract one
        "threshold": 0.8
    }
})
```

### Incremental Entity Addition

```python
# Build schema incrementally
schema = extractor.create_schema()

# Add entities in stages
schema.entities(["person", "location"])                        # Basic entities
schema.entities({"company": "Company or organization names"})  # With description
schema.entities({                                              # With full config
    "financial_term": {
        "description": "Financial instruments, metrics, or terminology",
        "threshold": 0.75
    }
})
```

## Domain-Specific Entities

### Legal Entities

```python
legal_schema = extractor.create_schema().entities({
    "party": "Parties involved in legal proceedings (plaintiff, defendant, etc.)",
    "law_firm": "Law firm or legal practice names",
    "court": "Court names or judicial bodies",
    "statute": "Legal statutes, laws, or regulations cited",
    "case": "Legal case names or citations",
    "judge": "Names of judges or magistrates",
    "legal_term": "Legal terminology or concepts"
})

legal_text = """
In the case of Smith v. Jones, Judge Sarah Williams of the Superior Court
ruled that the defendant violated Section 15.2 of the Consumer Protection Act.
The plaintiff was represented by Miller & Associates.
"""
results = extractor.extract(legal_text, legal_schema)
```

### Financial Entities

```python
finance_schema = extractor.create_schema().entities({
    "ticker": "Stock ticker symbols (e.g., AAPL, GOOGL)",
    "financial_metric": "Financial metrics like P/E ratio, market cap",
    "currency_amount": "Monetary values with currency symbols",
    "percentage": "Percentage values (e.g., 5.2%, -3%)",
    "financial_org": "Banks, investment firms, financial institutions",
    "market_index": "Stock market indices (S&P 500, NASDAQ, etc.)"
})

finance_text = """
AAPL rose 3.5% to $185.50 after beating earnings expectations.
The company's P/E ratio of 28.5 attracted Goldman Sachs analysts.
The NASDAQ composite gained 1.2% for the day.
"""
results = extractor.extract(finance_text, finance_schema)
```

### Scientific Entities

```python
science_schema = extractor.create_schema().entities({
    "chemical": "Chemical compounds or elements",
    "organism": "Biological organisms, species names",
    "gene": "Gene names or identifiers",
    "measurement": "Scientific measurements with units",
    "research_method": "Research techniques or methodologies",
    "institution": "Universities or research institutions"
})

science_text = """
Researchers at MIT discovered that the BRCA1 gene mutation increases
cancer risk by 70%. The study used CRISPR-Cas9 to modify DNA sequences
in Mus musculus specimens, measuring tumor growth in millimeters.
"""
results = extractor.extract(science_text, science_schema)
```

## Best Practices

### 1. Use Descriptive Entity Names

```python
# Good - Clear, specific entity types
schema.entities(["drug_name", "medical_device", "procedure_name"])

# Less ideal - Too generic
schema.entities(["thing", "item", "stuff"])
```

### 2. Provide Context with Descriptions

```python
# Good - Clear descriptions
schema.entities({
    "acquisition_company": "Company that is acquiring another company",
    "target_company": "Company being acquired",
    "acquisition_price": "Purchase price or valuation of acquisition"
})

# Less ideal - No context
schema.entities(["company1", "company2", "price"])
```

packages/GLiNER2/tutorial/3-json_extraction.md (new file, 504 lines)

# GLiNER2 JSON Structure Extraction Tutorial

Learn how to extract complex structured data from text using GLiNER2's hierarchical extraction capabilities.

## Table of Contents
- [Quick API with extract_json](#quick-api-with-extract_json)
- [Field Types and Specifications](#field-types-and-specifications)
- [Multiple Instances](#multiple-instances)
- [Schema Builder (Multi-Task)](#schema-builder-multi-task)
- [Real-World Examples](#real-world-examples)
- [Best Practices](#best-practices)

## Quick API with extract_json

For structure-only extraction, use the `extract_json()` method with the simple dictionary format:

### Basic Structure Extraction

```python
from gliner2 import GLiNER2

# Load model
extractor = GLiNER2.from_pretrained("your-model-name")

# Simple product extraction
text = "The MacBook Pro costs $1999 and features M3 chip, 16GB RAM, and 512GB storage."
results = extractor.extract_json(
    text,
    {
        "product": [
            "name::str",
            "price",
            "features"
        ]
    }
)
print(results)
# Output: {
#     'product': [{
#         'name': 'MacBook Pro',
#         'price': ['$1999'],
#         'features': ['M3 chip', '16GB RAM', '512GB storage']
#     }]
# }
```

### Contact Information

```python
text = """
Contact: John Smith
Email: john@example.com
Phones: 555-1234, 555-5678
Address: 123 Main St, NYC
"""

results = extractor.extract_json(
    text,
    {
        "contact": [
            "name::str",
            "email::str",
            "phone::list",
            "address"
        ]
    }
)
# Output: {
#     'contact': [{
#         'name': 'John Smith',
#         'email': 'john@example.com',
#         'phone': ['555-1234', '555-5678'],
#         'address': ['123 Main St, NYC']
#     }]
# }
```

## Field Types and Specifications

### Field Specification Format

Fields support flexible specifications using `::` separators:

```
"field_name::type::description"
"field_name::[choice1|choice2|choice3]::type::description"
"field_name::description"  # defaults to list type
"field_name"               # simple field, defaults to list
```

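A parser for this spec format can be sketched in a few lines. This mirrors the grammar above; it is not GLiNER2's internal parser:

```python
def parse_field_spec(spec):
    """Split a field spec into (name, choices, dtype, description)."""
    parts = spec.split("::")
    name, choices, dtype, description = parts[0], None, "list", None
    rest = parts[1:]
    # Optional [a|b|c] choice list comes right after the name
    if rest and rest[0].startswith("[") and rest[0].endswith("]"):
        choices = rest.pop(0)[1:-1].split("|")
    # Optional explicit type; anything else is the description
    if rest and rest[0] in ("str", "list"):
        dtype = rest.pop(0)
    if rest:
        description = rest[0]
    return name, choices, dtype, description

print(parse_field_spec("seating::[indoor|outdoor|bar]::str::Seating preference"))
# ('seating', ['indoor', 'outdoor', 'bar'], 'str', 'Seating preference')
```
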
### String vs List Fields

```python
text = """
Tech Conference 2024 on June 15th in San Francisco.
Topics include AI, Machine Learning, and Cloud Computing.
Registration fee: $299 for early bird tickets.
"""

results = extractor.extract_json(
    text,
    {
        "event": [
            "name::str::Event or conference name",
            "date::str::Event date",
            "location::str",
            "topics::list::Conference topics",
            "registration_fee::str"
        ]
    }
)
# Output: {
#     'event': [{
#         'name': 'Tech Conference 2024',
#         'date': 'June 15th',
#         'location': 'San Francisco',
#         'topics': ['AI', 'Machine Learning', 'Cloud Computing'],
#         'registration_fee': '$299'
#     }]
# }
```

### Choice Fields (Classification within Structure)

```python
text = """
Reservation at Le Bernardin for 4 people on March 15th at 7:30 PM.
We'd prefer outdoor seating. Two guests are vegetarian and one is gluten-free.
"""

results = extractor.extract_json(
    text,
    {
        "reservation": [
            "restaurant::str::Restaurant name",
            "date::str",
            "time::str",
            "party_size::[1|2|3|4|5|6+]::str::Number of guests",
            "seating::[indoor|outdoor|bar]::str::Seating preference",
            "dietary::[vegetarian|vegan|gluten-free|none]::list::Dietary restrictions"
        ]
    }
)
# Output: {
#     'reservation': [{
#         'restaurant': 'Le Bernardin',
#         'date': 'March 15th',
#         'time': '7:30 PM',
#         'party_size': '4',
#         'seating': 'outdoor',
#         'dietary': ['vegetarian', 'gluten-free']
#     }]
# }
```

## Multiple Instances

GLiNER2 automatically extracts ALL instances of a structure found in text:

### Multiple Transactions

```python
text = """
Recent transactions:
- Jan 5: Starbucks $5.50 (food)
- Jan 5: Uber $23.00 (transport)
- Jan 6: Amazon $156.99 (shopping)
"""

results = extractor.extract_json(
    text,
    {
        "transaction": [
            "date::str",
            "merchant::str",
            "amount::str",
            "category::[food|transport|shopping|utilities]::str"
        ]
    }
)
# Output: {
#   'transaction': [
#     {'date': 'Jan 5', 'merchant': 'Starbucks', 'amount': '$5.50', 'category': 'food'},
#     {'date': 'Jan 5', 'merchant': 'Uber', 'amount': '$23.00', 'category': 'transport'},
#     {'date': 'Jan 6', 'merchant': 'Amazon', 'amount': '$156.99', 'category': 'shopping'}
#   ]
# }
```
### Multiple Hotel Bookings

```python
text = """
Alice Brown booked the Hilton Downtown from March 10 to March 12. She selected a double room
for $340 total with breakfast and parking included.

Robert Taylor reserved The Grand Hotel, April 1 to April 5, suite at $1,200 total.
Amenities include breakfast, wifi, gym, and spa access.
"""

results = extractor.extract_json(
    text,
    {
        "booking": [
            "guest::str::Guest name",
            "hotel::str::Hotel name",
            "check_in::str",
            "check_out::str",
            "room_type::[single|double|suite|deluxe]::str",
            "total_price::str",
            "amenities::[breakfast|wifi|parking|gym|spa]::list"
        ]
    }
)
# Output: {
#   'booking': [
#     {
#       'guest': 'Alice Brown',
#       'hotel': 'Hilton Downtown',
#       'check_in': 'March 10',
#       'check_out': 'March 12',
#       'room_type': 'double',
#       'total_price': '$340',
#       'amenities': ['breakfast', 'parking']
#     },
#     {
#       'guest': 'Robert Taylor',
#       'hotel': 'The Grand Hotel',
#       'check_in': 'April 1',
#       'check_out': 'April 5',
#       'room_type': 'suite',
#       'total_price': '$1,200',
#       'amenities': ['breakfast', 'wifi', 'gym', 'spa']
#     }
#   ]
# }
```
## Schema Builder (Multi-Task)

Use `create_schema()` only when combining structured extraction with other tasks (entities, classification):

### Multi-Task Extraction

```python
# Use schema builder for multi-task scenarios
schema = (extractor.create_schema()
    # Extract entities
    .entities(["person", "company", "location"])

    # Classify sentiment
    .classification("sentiment", ["positive", "negative", "neutral"])

    # Extract structured product info
    .structure("product")
        .field("name", dtype="str")
        .field("price", dtype="str")
        .field("features", dtype="list")
        .field("category", dtype="str", choices=["electronics", "software", "service"])
)

text = "Apple CEO Tim Cook announced iPhone 15 for $999 with amazing new features. This is exciting!"
results = extractor.extract(text, schema)
# Output: {
#   'entities': {'person': ['Tim Cook'], 'company': ['Apple'], 'location': []},
#   'sentiment': 'positive',
#   'product': [{
#     'name': 'iPhone 15',
#     'price': '$999',
#     'features': ['amazing new features'],
#     'category': 'electronics'
#   }]
# }
```
### Advanced Configuration

```python
schema = (extractor.create_schema()
    .classification("urgency", ["low", "medium", "high"])

    .structure("support_ticket")
        .field("ticket_id", dtype="str", threshold=0.9)  # High precision
        .field("customer", dtype="str", description="Customer name")
        .field("issue", dtype="str", description="Problem description")
        .field("priority", dtype="str", choices=["low", "medium", "high", "urgent"])
        .field("tags", dtype="list", choices=["bug", "feature", "support", "billing"])
)
```
## Examples

### Financial Transaction Processing

```python
text = """
Goldman Sachs processed a $2.5M equity trade for Tesla Inc. on March 15, 2024.
Commission: $1,250. Status: Completed.
"""

results = extractor.extract_json(
    text,
    {
        "transaction": [
            "broker::str::Financial institution",
            "amount::str::Transaction amount",
            "security::str::Stock or financial instrument",
            "date::str::Transaction date",
            "commission::str::Fees charged",
            "status::[pending|completed|failed]::str",
            "type::[equity|bond|option|future]::str"
        ]
    }
)
# Output: {
#   'transaction': [{
#     'broker': 'Goldman Sachs',
#     'amount': '$2.5M',
#     'security': 'Tesla Inc.',
#     'date': 'March 15, 2024',
#     'commission': '$1,250',
#     'status': 'completed',
#     'type': 'equity'
#   }]
# }
```
### Medical Prescription Extraction

```python
text = """
Patient: Sarah Johnson, 34, presented with chest pain.
Prescribed: Lisinopril 10mg daily, Metoprolol 25mg twice daily.
Follow-up scheduled for next Tuesday.
"""

results = extractor.extract_json(
    text,
    {
        "patient": [
            "name::str::Patient full name",
            "age::str::Patient age",
            "symptoms::list::Reported symptoms"
        ],
        "prescription": [
            "medication::str::Drug name",
            "dosage::str::Dosage amount",
            "frequency::str::How often to take"
        ]
    }
)
# Output: {
#   'patient': [{
#     'name': 'Sarah Johnson',
#     'age': '34',
#     'symptoms': ['chest pain']
#   }],
#   'prescription': [
#     {'medication': 'Lisinopril', 'dosage': '10mg', 'frequency': 'daily'},
#     {'medication': 'Metoprolol', 'dosage': '25mg', 'frequency': 'twice daily'}
#   ]
# }
```
### E-commerce Order Processing

```python
text = """
Order #ORD-2024-001 for Alexandra Thompson
Items: Laptop Stand (2x $45.99), Wireless Mouse (1x $29.99), USB Hub (3x $35.50)
Subtotal: $228.46, Tax: $18.28, Total: $246.74
Status: Processing
"""

results = extractor.extract_json(
    text,
    {
        "order": [
            "order_id::str::Order number",
            "customer::str::Customer name",
            "items::list::Product names",
            "quantities::list::Item quantities",
            "unit_prices::list::Individual prices",
            "subtotal::str",
            "tax::str",
            "total::str",
            "status::[pending|processing|shipped|delivered]::str"
        ]
    }
)
# Output: {
#   'order': [{
#     'order_id': 'ORD-2024-001',
#     'customer': 'Alexandra Thompson',
#     'items': ['Laptop Stand', 'Wireless Mouse', 'USB Hub'],
#     'quantities': ['2', '1', '3'],
#     'unit_prices': ['$45.99', '$29.99', '$35.50'],
#     'subtotal': '$228.46',
#     'tax': '$18.28',
#     'total': '$246.74',
#     'status': 'processing'
#   }]
# }
```
## Confidence Scores and Character Positions

You can include confidence scores and character-level start/end positions for structured extraction:

```python
# Extract with confidence scores
text = "The MacBook Pro costs $1999 and features M3 chip, 16GB RAM, and 512GB storage."
results = extractor.extract_json(
    text,
    {
        "product": [
            "name::str",
            "price",
            "features"
        ]
    },
    include_confidence=True
)
# Output: {
#   'product': [{
#     'name': {'text': 'MacBook Pro', 'confidence': 0.95},
#     'price': [{'text': '$1999', 'confidence': 0.92}],
#     'features': [
#       {'text': 'M3 chip', 'confidence': 0.88},
#       {'text': '16GB RAM', 'confidence': 0.90},
#       {'text': '512GB storage', 'confidence': 0.87}
#     ]
#   }]
# }

# Extract with character positions (spans)
results = extractor.extract_json(
    text,
    {
        "product": [
            "name::str",
            "price"
        ]
    },
    include_spans=True
)
# Output: {
#   'product': [{
#     'name': {'text': 'MacBook Pro', 'start': 4, 'end': 15},
#     'price': [{'text': '$1999', 'start': 22, 'end': 27}]
#   }]
# }

# Extract with both confidence and spans
results = extractor.extract_json(
    text,
    {
        "product": [
            "name::str",
            "price",
            "features"
        ]
    },
    include_confidence=True,
    include_spans=True
)
# Output: {
#   'product': [{
#     'name': {'text': 'MacBook Pro', 'confidence': 0.95, 'start': 4, 'end': 15},
#     'price': [{'text': '$1999', 'confidence': 0.92, 'start': 22, 'end': 27}],
#     'features': [
#       {'text': 'M3 chip', 'confidence': 0.88, 'start': 41, 'end': 48},
#       {'text': '16GB RAM', 'confidence': 0.90, 'start': 50, 'end': 58},
#       {'text': '512GB storage', 'confidence': 0.87, 'start': 64, 'end': 77}
#     ]
#   }]
# }
```
**Note**: When `include_spans` or `include_confidence` is True:
- **String fields** (`dtype="str"`): Return dicts with `{'text': '...', 'confidence': 0.95, 'start': 0, 'end': 5}` (or a subset)
- **List fields** (`dtype="list"`): Return lists of dicts, each with text, confidence, and positions
- **Default** (both False): Returns simple strings or lists of strings
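When downstream code only needs the text, the enriched shape described in the note can be flattened back to plain strings. A minimal sketch, assuming the documented output shape; the helper name `plain_text` is hypothetical, not part of the GLiNER2 API:

```python
def plain_text(results):
    """Reduce confidence/span-enriched structured output to plain strings.

    Assumes the documented shape: string fields are dicts with a 'text' key,
    list fields are lists of such dicts.
    """
    plain = {}
    for struct, instances in results.items():
        plain[struct] = []
        for inst in instances:
            plain[struct].append({
                field: ([v["text"] for v in value] if isinstance(value, list)
                        else value["text"])
                for field, value in inst.items()
            })
    return plain

enriched = {
    "product": [{
        "name": {"text": "MacBook Pro", "confidence": 0.95},
        "price": [{"text": "$1999", "confidence": 0.92}],
    }]
}
print(plain_text(enriched))
# {'product': [{'name': 'MacBook Pro', 'price': ['$1999']}]}
```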
## Best Practices

### Data Types

- Use `::str` for single values (IDs, names, amounts)
- Use `::list` or default for multiple values (features, items, tags)
- Use choices `[opt1|opt2|opt3]` for standardized values
- Add descriptions for complex or domain-specific fields
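The four conventions above can be combined in a single schema dict. The field names below are illustrative (a hypothetical support-ticket schema), not from the tutorial; the dict would be passed as the second argument to `extract_json`:

```python
# One illustrative schema exercising all four data-type conventions
schema = {
    "ticket": [
        "ticket_id::str",                                 # single value
        "tags::list",                                     # multiple values
        "priority::[low|medium|high]::str",               # standardized choices
        "sla::str::Service-level agreement, e.g. '24h'",  # description for a domain term
    ]
}
print(schema["ticket"][2])
# priority::[low|medium|high]::str
```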
### Quick Decision Guide

**Use `extract_json()`** for:
- Structure-only extraction
- Quick data parsing
- Single extraction tasks

**Use `create_schema().extract()`** for:
- Multi-task scenarios (entities + structures + classification)
- When you need entities or classification alongside structures
- Complex extraction pipelines

---

packages/GLiNER2/tutorial/4-combined.md (new file, 357 lines)

# GLiNER2 Combining Schemas Tutorial

## Table of Contents
- [Why Combine Schemas](#why-combine-schemas)
- [Basic Combinations](#basic-combinations)
- [Advanced Multi-Task Schemas](#advanced-multi-task-schemas)
- [Real-World Applications](#real-world-applications)

## Why Combine Schemas

Combining schemas allows you to:
- Extract multiple types of information in one pass
- Maintain context between different extraction tasks
- Improve efficiency by avoiding multiple model calls
- Build comprehensive information extraction pipelines
## Basic Combinations

### Entities + Classification

```python
from gliner2 import GLiNER2

extractor = GLiNER2.from_pretrained("your-model-name")

# Sentiment analysis with entity extraction
schema = (extractor.create_schema()
    .entities(["person", "product", "company"])
    .classification("sentiment", ["positive", "negative", "neutral"])
    .classification("category", ["review", "news", "opinion"])
)

text = "Tim Cook announced that Apple's new iPhone is exceeding sales expectations."
results = extractor.extract(text, schema)
# Output: {
#   'entities': {
#     'person': ['Tim Cook'],
#     'product': ['iPhone'],
#     'company': ['Apple']
#   },
#   'sentiment': 'positive',
#   'category': 'news'
# }
```
### Entities + Structures

```python
schema = (extractor.create_schema()
    .entities({
        "person": "Names of people mentioned",
        "date": "Dates and time references"
    })
    .structure("appointment")
        .field("patient", dtype="str")
        .field("doctor", dtype="str")
        .field("date")
        .field("time")
        .field("type", dtype="str", choices=["checkup", "followup", "consultation"])
)

text = """
Dr. Sarah Johnson confirmed the appointment with John Smith for
March 15th at 2:30 PM. This will be a follow-up consultation
regarding his previous visit on February 1st.
"""
results = extractor.extract(text, schema)
```
### Classification + Structures

```python
schema = (extractor.create_schema()
    .classification("email_type",
        ["order_confirmation", "shipping_update", "promotional", "support"])
    .classification("priority", ["urgent", "normal", "low"])
    .structure("order_info")
        .field("order_number", dtype="str")
        .field("items")
        .field("total", dtype="str")
        .field("status", dtype="str",
            choices=["pending", "processing", "shipped", "delivered"])
)
```
## Advanced Multi-Task Schemas

### Complete Document Analysis

```python
# Comprehensive invoice extraction
invoice_schema = (extractor.create_schema()
    # Document classification
    .classification("document_type",
        ["invoice", "credit_note", "purchase_order", "receipt"])
    .classification("payment_status",
        ["paid", "unpaid", "partial", "overdue"])

    # Key entities
    .entities({
        "company": "Company names (buyer or seller)",
        "person": "Contact person names",
        "date": "Important dates",
        "amount": "Monetary amounts"
    })

    # Structured information
    .structure("invoice_header")
        .field("invoice_number", dtype="str")
        .field("issue_date", dtype="str")
        .field("due_date", dtype="str")
        .field("vendor_name", dtype="str")
        .field("customer_name", dtype="str")

    .structure("line_item")
        .field("description", dtype="str")
        .field("quantity")
        .field("unit_price")
        .field("amount")
        .field("tax_rate", dtype="str", choices=["0%", "5%", "10%", "20%"])

    .structure("payment_info")
        .field("method", dtype="str",
            choices=["bank_transfer", "credit_card", "check", "cash"])
        .field("terms", description="Payment terms like NET30")
        .field("bank_details", dtype="list")
)
```
### Customer Feedback Analysis

```python
feedback_schema = (extractor.create_schema()
    # Overall classifications
    .classification("sentiment", ["positive", "negative", "neutral", "mixed"])
    .classification("intent", {
        "complaint": "Customer expressing dissatisfaction",
        "compliment": "Customer expressing satisfaction",
        "suggestion": "Customer providing improvement ideas",
        "question": "Customer asking for information"
    }, multi_label=True)

    # Extract mentioned entities
    .entities({
        "product": "Products or services mentioned",
        "feature": "Specific features discussed",
        "competitor": "Competing products mentioned",
        "price_mention": "Price points or cost references"
    })

    # Structured feedback components
    .structure("issue")
        .field("problem", dtype="str")
        .field("severity", dtype="str", choices=["critical", "major", "minor"])
        .field("affected_area", dtype="list")

    .structure("suggestion")
        .field("improvement", dtype="str")
        .field("benefit", description="Expected benefit of the suggestion")
)
```
### News Article Analysis

```python
news_schema = (extractor.create_schema()
    # Article metadata
    .classification("category",
        ["politics", "business", "technology", "sports", "entertainment"])
    .classification("bias", ["left", "center", "right", "neutral"])
    .classification("factuality", ["fact", "opinion", "analysis", "speculation"])

    # Key entities
    .entities({
        "person": "People mentioned in the article",
        "organization": "Companies, agencies, or groups",
        "location": "Places, cities, or countries",
        "event": "Named events or incidents"
    })

    # Structured content
    .structure("quote")
        .field("speaker", dtype="str")
        .field("statement", dtype="str")
        .field("context", description="Context of the quote")

    .structure("claim")
        .field("statement", dtype="str")
        .field("source", dtype="str")
        .field("evidence", dtype="list")
)
```
## Real-World Applications

### E-commerce Product Listing

```python
product_schema = (extractor.create_schema()
    # Listing classification
    .classification("condition", ["new", "used", "refurbished", "for_parts"])
    .classification("listing_type", ["buy_now", "auction", "best_offer"])

    # Extract key entities
    .entities({
        "brand": "Product brand or manufacturer",
        "model": "Specific model name or number",
        "color": "Product colors mentioned",
        "size": "Size specifications"
    })

    # Product details
    .structure("product")
        .field("title", dtype="str")
        .field("price", dtype="str")
        .field("features", dtype="list")
        .field("category", dtype="str")

    # Shipping information
    .structure("shipping")
        .field("method", dtype="list",
            choices=["standard", "express", "overnight", "international"])
        .field("cost", dtype="str")
        .field("delivery_time", description="Estimated delivery timeframe")

    # Seller information
    .structure("seller")
        .field("name", dtype="str")
        .field("rating", dtype="str")
        .field("location", dtype="str")
)
```
### Healthcare Clinical Note

```python
clinical_schema = (extractor.create_schema()
    # Note classification
    .classification("visit_type",
        ["initial_consultation", "follow_up", "emergency", "routine_checkup"])
    .classification("urgency", ["urgent", "routine", "elective"])

    # Medical entities
    .entities({
        "symptom": "Patient reported symptoms",
        "diagnosis": "Medical diagnoses or conditions",
        "medication": "Prescribed or mentioned medications",
        "procedure": "Medical procedures or tests",
        "body_part": "Anatomical references"
    })

    # Patient information
    .structure("patient_info")
        .field("name", dtype="str")
        .field("age", dtype="str")
        .field("gender", dtype="str", choices=["male", "female", "other"])
        .field("chief_complaint", dtype="str")

    # Clinical findings
    .structure("vital_signs")
        .field("blood_pressure", dtype="str")
        .field("heart_rate", dtype="str")
        .field("temperature", dtype="str")
        .field("respiratory_rate", dtype="str")

    # Treatment plan
    .structure("prescription")
        .field("medication", dtype="str")
        .field("dosage", dtype="str")
        .field("frequency")
        .field("duration")
        .field("route", dtype="str", choices=["oral", "IV", "topical", "injection"])
)
```
### Legal Document Analysis

```python
legal_schema = (extractor.create_schema()
    # Document classification
    .classification("document_type",
        ["contract", "memorandum", "brief", "motion", "order"])
    .classification("jurisdiction",
        ["federal", "state", "local", "international"])

    # Legal entities
    .entities({
        "party": "Parties involved (plaintiff, defendant, etc.)",
        "attorney": "Legal representatives",
        "judge": "Judicial officers",
        "statute": "Laws or regulations cited",
        "case_citation": "Referenced legal cases"
    })

    # Contract terms
    .structure("contract_term")
        .field("clause_type", dtype="str",
            choices=["payment", "delivery", "warranty", "liability", "termination"])
        .field("obligation", dtype="str")
        .field("party_responsible", dtype="str")
        .field("deadline")

    # Legal claims
    .structure("claim")
        .field("type", dtype="str")
        .field("plaintiff", dtype="str")
        .field("defendant", dtype="str")
        .field("amount", dtype="str")
        .field("basis", description="Legal basis for the claim")
)
```
## Using Confidence Scores and Character Positions with Combined Schemas

When using combined schemas, the `include_confidence` and `include_spans` parameters apply to all extraction types:

```python
schema = (extractor.create_schema()
    .entities(["person", "company"])
    .classification("sentiment", ["positive", "negative", "neutral"])
    .relations(["works_for"])
    .structure("product")
        .field("name", dtype="str")
        .field("price", dtype="str")
)

text = "Tim Cook works for Apple. The iPhone 15 costs $999. This is exciting!"
results = extractor.extract(
    text,
    schema,
    include_confidence=True,
    include_spans=True
)
# Output: {
#   'entities': {
#     'person': [
#       {'text': 'Tim Cook', 'confidence': 0.95, 'start': 0, 'end': 8}
#     ],
#     'company': [
#       {'text': 'Apple', 'confidence': 0.92, 'start': 19, 'end': 24}
#     ]
#   },
#   'sentiment': {'label': 'positive', 'confidence': 0.88},
#   'relation_extraction': {
#     'works_for': [{
#       'head': {'text': 'Tim Cook', 'confidence': 0.95, 'start': 0, 'end': 8},
#       'tail': {'text': 'Apple', 'confidence': 0.92, 'start': 19, 'end': 24}
#     }]
#   },
#   'product': [{
#     'name': {'text': 'iPhone 15', 'confidence': 0.90, 'start': 30, 'end': 39},
#     'price': {'text': '$999', 'confidence': 0.88, 'start': 46, 'end': 50}
#   }]
# }
```
**Note**: The `include_confidence` and `include_spans` parameters work consistently across all extraction types (entities, classifications, relations, and structures) when using combined schemas.

---

packages/GLiNER2/tutorial/5-validator.md (new file, 112 lines)

# GLiNER2 Regex Validators

Regex validators filter extracted spans to ensure they match expected patterns, improving extraction quality and reducing false positives.

## Quick Start

```python
from gliner2 import GLiNER2, RegexValidator

extractor = GLiNER2.from_pretrained("your-model")

# Create validator and apply to field
email_validator = RegexValidator(r"^[\w\.-]+@[\w\.-]+\.\w+$")
schema = (extractor.create_schema()
    .structure("contact")
        .field("email", dtype="str", validators=[email_validator])
)
```

## RegexValidator Parameters

- **pattern**: Regex pattern (string or compiled Pattern)
- **mode**: `"full"` (exact match) or `"partial"` (substring match)
- **exclude**: `False` (keep matches) or `True` (exclude matches)
- **flags**: Regex flags like `re.IGNORECASE` (for string patterns only)
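The parameter semantics can be sketched with plain `re`: "full" maps to `re.fullmatch`, "partial" to `re.search`, and `exclude` inverts the keep/drop decision. This is a standalone illustration of the behavior described above, not the library's actual implementation; the helper name `regex_filter` is mine:

```python
import re

def regex_filter(spans, pattern, mode="full", exclude=False, flags=0):
    """Keep spans that satisfy the validator semantics described above.

    mode="full" requires the whole span to match; mode="partial" accepts a
    substring match; exclude=True drops matching spans instead of keeping them.
    """
    compiled = re.compile(pattern, flags) if isinstance(pattern, str) else pattern
    match = compiled.fullmatch if mode == "full" else compiled.search
    return [s for s in spans if bool(match(s)) != exclude]

print(regex_filter(
    ["john@company.com", "not-an-email", "jane@domain.org"],
    r"^[\w\.-]+@[\w\.-]+\.\w+$",
))
# ['john@company.com', 'jane@domain.org']

print(regex_filter(["iPhone", "Test Phone"], r"^(test|demo)",
                   mode="partial", exclude=True, flags=re.IGNORECASE))
# ['iPhone']
```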
## Examples

### Email Validation
```python
email_validator = RegexValidator(r"^[\w\.-]+@[\w\.-]+\.\w+$")

text = "Contact: john@company.com, not-an-email, jane@domain.org"
# Output: ['john@company.com', 'jane@domain.org']
```

### Phone Numbers (US Format)
```python
phone_validator = RegexValidator(r"\(\d{3}\)\s\d{3}-\d{4}", mode="partial")

text = "Call (555) 123-4567 or 5551234567"
# Output: ['(555) 123-4567']  # Second number filtered out
```

### URLs Only
```python
url_validator = RegexValidator(r"^https?://", mode="partial")

text = "Visit https://example.com or www.site.com"
# Output: ['https://example.com']  # www.site.com filtered out
```
### Exclude Test Data
```python
import re

no_test_validator = RegexValidator(r"^(test|demo|sample)", mode="partial",
                                   exclude=True, flags=re.IGNORECASE)

text = "Products: iPhone, Test Phone, Samsung Galaxy"
# Output: ['iPhone', 'Samsung Galaxy']  # Test Phone excluded
```
### Length Constraints
```python
length_validator = RegexValidator(r"^.{5,50}$")  # 5-50 characters

text = "Names: Jo, Alexander, A Very Long Name That Definitely Goes Past The Fifty Character Limit"
# Output: ['Alexander']  # Others filtered by length
```
### Multiple Validators
```python
import re

# All validators must pass
username_validators = [
    RegexValidator(r"^[a-zA-Z0-9_]+$"),  # Alphanumeric + underscore
    RegexValidator(r"^.{3,20}$"),        # 3-20 characters
    RegexValidator(r"^admin", mode="partial", exclude=True,
                   flags=re.IGNORECASE)  # No names starting with "admin"
]

schema = (extractor.create_schema()
    .structure("user")
        .field("username", dtype="str", validators=username_validators)
)

text = "Users: ab, john_doe, user@domain, admin, valid_user123"
# Output: ['john_doe', 'valid_user123']
```
## Common Patterns

| Use Case | Pattern | Mode |
|----------|---------|------|
| Email | `r"^[\w\.-]+@[\w\.-]+\.\w+$"` | full |
| Phone (US) | `r"\(\d{3}\)\s\d{3}-\d{4}"` | partial |
| URL | `r"^https?://"` | partial |
| Numbers only | `r"^\d+$"` | full |
| No spaces | `r"^\S+$"` | full |
| Min length | `r"^.{5,}$"` | full |
| Alphanumeric | `r"^[a-zA-Z0-9]+$"` | full |
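The patterns in the table can be spot-checked directly with Python's `re` module before wiring them into a validator ("full" corresponds roughly to `re.fullmatch`, "partial" to `re.search`):

```python
import re

# Quick sanity checks for a few table entries
assert re.fullmatch(r"[\w\.-]+@[\w\.-]+\.\w+", "john@company.com")
assert re.search(r"^https?://", "https://example.com")
assert not re.search(r"^https?://", "www.site.com")
assert re.fullmatch(r"\d+", "12345")
assert not re.fullmatch(r"\d+", "12a")
print("all patterns behave as expected")
```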
## Best Practices

1. **Use specific patterns** - More specific = fewer false positives
2. **Test your regex** - Validate patterns before deployment
3. **Combine validators** - Chain multiple simple validators
4. **Consider case sensitivity** - Use `re.IGNORECASE` when needed
5. **Start simple** - Begin with basic patterns, refine as needed

## Performance Notes

- Validators run after span extraction but before formatting
- Failed validation simply excludes the span (no errors)
- Multiple validators use short-circuit evaluation (stops at first failure)
- Compiled patterns are cached automatically

---

packages/GLiNER2/tutorial/6-relation_extraction.md (new file, 643 lines)

# GLiNER2 Relation Extraction Tutorial

Learn how to extract relations between entities from text using GLiNER2's relation extraction capabilities.

## Table of Contents
- [Basic Relation Extraction](#basic-relation-extraction)
- [Multiple Relation Types](#multiple-relation-types)
- [Relation Extraction with Descriptions](#relation-extraction-with-descriptions)
- [Custom Thresholds](#custom-thresholds)
- [Batch Processing](#batch-processing)
- [Combining with Other Tasks](#combining-with-other-tasks)
- [Real-World Examples](#real-world-examples)
- [Best Practices](#best-practices)
## Basic Relation Extraction

### Simple Example

```python
from gliner2 import GLiNER2

# Load model
extractor = GLiNER2.from_pretrained("your-model-name")

# Extract relations
text = "John works for Apple Inc. and lives in San Francisco."
results = extractor.extract_relations(
    text,
    ["works_for", "lives_in"]
)
print(results)
# Output: {
#   'relation_extraction': {
#     'works_for': [('John', 'Apple Inc.')],
#     'lives_in': [('John', 'San Francisco')]
#   }
# }
```
### Using Schema Builder

```python
# Same extraction using a schema
schema = extractor.create_schema().relations([
    "works_for", "lives_in"
])
results = extractor.extract(text, schema)
```
### Understanding the Output Format

Relations are returned as `(source, target)` tuples grouped under the `relation_extraction` key. **All requested relation types are included in the output, even if no relations are found** (they appear as empty lists `[]`):

```python
text = "Alice manages the Engineering team. Bob reports to Alice."
results = extractor.extract_relations(
    text,
    ["manages", "reports_to", "founded"]  # Note: "founded" not found in text
)
# Output: {
#   'relation_extraction': {
#     'manages': [('Alice', 'Engineering team')],
#     'reports_to': [('Bob', 'Alice')],
#     'founded': []  # Empty list - relation type requested but not found
#   }
# }
```

This guarantees a consistent output structure: every requested relation type is always present in the results, which makes the output easier to process programmatically.
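Because every requested relation type is always present (possibly as an empty list), downstream code can iterate without key checks. A minimal sketch using a hard-coded result dict in the documented shape:

```python
# Hard-coded result in the documented shape (no model call needed here)
results = {
    "relation_extraction": {
        "manages": [("Alice", "Engineering team")],
        "reports_to": [("Bob", "Alice")],
        "founded": [],  # requested but not found: still present, just empty
    }
}

# Flatten to (head, relation, tail) triples without any key existence checks
for relation, pairs in results["relation_extraction"].items():
    for head, tail in pairs:
        print(f"{head} --{relation}--> {tail}")
# Alice --manages--> Engineering team
# Bob --reports_to--> Alice
```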
## Multiple Relation Types

You can extract multiple relation types in a single call:

```python
text = """
Sarah founded TechCorp in 2020. She is married to Mike,
who works at Google. TechCorp is located in Seattle.
"""

results = extractor.extract_relations(
    text,
    ["founded", "married_to", "works_at", "located_in"]
)
# Output: {
#   'relation_extraction': {
#     'founded': [('Sarah', 'TechCorp')],
#     'married_to': [('Sarah', 'Mike')],
#     'works_at': [('Mike', 'Google')],
#     'located_in': [('TechCorp', 'Seattle')]
#   }
# }
```
### Multiple Instances per Relation Type

GLiNER2 automatically extracts all relation instances found in the text:

```python
text = """
John works for Microsoft. Mary works for Google.
Bob works for Apple. All three live in California.
"""

results = extractor.extract_relations(
    text,
    ["works_for", "lives_in"]
)
# Output: {
#   'relation_extraction': {
#     'works_for': [
#       ('John', 'Microsoft'),
#       ('Mary', 'Google'),
#       ('Bob', 'Apple')
#     ],
#     'lives_in': [
#       ('John', 'California'),
#       ('Mary', 'California'),
#       ('Bob', 'California')
#     ]
#   }
# }
```
## Relation Extraction with Descriptions
|
||||
|
||||
Providing descriptions helps improve extraction accuracy by clarifying the relation semantics:
|
||||
|
||||
```python
|
||||
schema = extractor.create_schema().relations({
|
||||
"works_for": "Employment relationship where person works at organization",
|
||||
"founded": "Founding relationship where person created organization",
|
||||
"acquired": "Acquisition relationship where company bought another company",
|
||||
"located_in": "Geographic relationship where entity is in a location"
|
||||
})
|
||||
|
||||
text = """
|
||||
Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne, California.
|
||||
Tesla acquired SolarCity in 2016. Many engineers work for SpaceX.
|
||||
"""
|
||||
|
||||
results = extractor.extract(text, schema)
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```python
|
||||
schema = extractor.create_schema().relations({
|
||||
"works_for": {
|
||||
"description": "Employment or professional relationship",
|
||||
"threshold": 0.7 # Higher precision for employment relations
|
||||
},
|
||||
"located_in": {
|
||||
"description": "Geographic containment relationship",
|
||||
"threshold": 0.6 # Moderate threshold
|
||||
},
|
||||
"reports_to": {
|
||||
"description": "Organizational hierarchy relationship",
|
||||
"threshold": 0.8 # Very high precision
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Custom Thresholds
|
||||
|
||||
### Global Threshold
|
||||
|
||||
```python
|
||||
# High-precision relation extraction
|
||||
results = extractor.extract_relations(
|
||||
text,
|
||||
["acquired", "merged_with"],
|
||||
threshold=0.8 # High confidence required
|
||||
)
|
||||
```
|
||||
|
||||
### Per-Relation Thresholds
|
||||
|
||||
```python
|
||||
schema = extractor.create_schema().relations({
|
||||
"acquired": {
|
||||
"description": "Company acquisition relationship",
|
||||
"threshold": 0.9 # Very high precision
|
||||
},
|
||||
"partnered_with": {
|
||||
"description": "Partnership or collaboration relationship",
|
||||
"threshold": 0.6 # Moderate threshold
|
||||
},
|
||||
"competes_with": {
|
||||
"description": "Competitive relationship",
|
||||
"threshold": 0.5 # Lower threshold for implicit relations
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### With Confidence Scores and Character Positions
|
||||
|
||||
You can include confidence scores and character-level start/end positions for relation extractions:
|
||||
|
||||
```python
|
||||
# Extract relations with confidence scores
|
||||
text = "John works for Apple Inc. and lives in San Francisco."
|
||||
results = extractor.extract_relations(
|
||||
text,
|
||||
["works_for", "lives_in"],
|
||||
include_confidence=True
|
||||
)
|
||||
print(results)
|
||||
# Output: {
|
||||
# 'relation_extraction': {
|
||||
# 'works_for': [{
|
||||
# 'head': {'text': 'John', 'confidence': 0.95},
|
||||
# 'tail': {'text': 'Apple Inc.', 'confidence': 0.92}
|
||||
# }],
|
||||
# 'lives_in': [{
|
||||
# 'head': {'text': 'John', 'confidence': 0.94},
|
||||
# 'tail': {'text': 'San Francisco', 'confidence': 0.91}
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
|
||||
# Extract with character positions (spans)
|
||||
results = extractor.extract_relations(
|
||||
text,
|
||||
["works_for", "lives_in"],
|
||||
include_spans=True
|
||||
)
|
||||
print(results)
|
||||
# Output: {
|
||||
# 'relation_extraction': {
|
||||
# 'works_for': [{
|
||||
# 'head': {'text': 'John', 'start': 0, 'end': 4},
|
||||
# 'tail': {'text': 'Apple Inc.', 'start': 15, 'end': 25}
|
||||
# }],
|
||||
# 'lives_in': [{
|
||||
# 'head': {'text': 'John', 'start': 0, 'end': 4},
|
||||
# 'tail': {'text': 'San Francisco', 'start': 33, 'end': 46}
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
|
||||
# Extract with both confidence and spans
|
||||
results = extractor.extract_relations(
|
||||
text,
|
||||
["works_for", "lives_in"],
|
||||
include_confidence=True,
|
||||
include_spans=True
|
||||
)
|
||||
print(results)
|
||||
# Output: {
|
||||
# 'relation_extraction': {
|
||||
# 'works_for': [{
|
||||
# 'head': {'text': 'John', 'confidence': 0.95, 'start': 0, 'end': 4},
|
||||
# 'tail': {'text': 'Apple Inc.', 'confidence': 0.92, 'start': 15, 'end': 25}
|
||||
# }],
|
||||
# 'lives_in': [{
|
||||
# 'head': {'text': 'John', 'confidence': 0.94, 'start': 0, 'end': 4},
|
||||
# 'tail': {'text': 'San Francisco', 'confidence': 0.91, 'start': 33, 'end': 46}
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
```
|
||||
|
||||
**Note**: When `include_spans` or `include_confidence` is True, relations are returned as dictionaries with `head` and `tail` keys, each containing the extracted text along with optional confidence scores and character positions. When both are False (default), relations are returned as simple tuples `(head, tail)`.
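
Code that has to consume both shapes can normalize them with a small helper. A minimal sketch (the `results` dict below is illustrative sample data, not real model output):

```python
def relation_pair(entry):
    """Normalize one relation entry to a plain (head, tail) tuple.

    Handles both output shapes: the default tuple form and the dict
    form returned when include_confidence or include_spans is set.
    """
    if isinstance(entry, dict):
        return (entry["head"]["text"], entry["tail"]["text"])
    return (entry[0], entry[1])

# Illustrative results mixing both shapes
results = {
    "relation_extraction": {
        "works_for": [{
            "head": {"text": "John", "confidence": 0.95},
            "tail": {"text": "Apple Inc.", "confidence": 0.92},
        }],
        "lives_in": [("John", "San Francisco")],
    }
}

for rel_type, entries in results["relation_extraction"].items():
    for head, tail in map(relation_pair, entries):
        print(f"{head} --[{rel_type}]--> {tail}")
# John --[works_for]--> Apple Inc.
# John --[lives_in]--> San Francisco
```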

## Batch Processing

Process multiple texts efficiently:

```python
texts = [
    "John works for Microsoft and lives in Seattle.",
    "Sarah founded TechStartup in 2020.",
    "Bob reports to Alice at Google."
]

results = extractor.batch_extract_relations(
    texts,
    ["works_for", "founded", "reports_to", "lives_in"],
    batch_size=8
)
# Output: [
#     {
#         'relation_extraction': {
#             'works_for': [('John', 'Microsoft')],
#             'lives_in': [('John', 'Seattle')],
#             'founded': [],     # Not found in first text
#             'reports_to': []   # Not found in first text
#         }
#     },
#     {
#         'relation_extraction': {
#             'works_for': [],   # Not found in second text
#             'founded': [('Sarah', 'TechStartup')],
#             'reports_to': [],  # Not found in second text
#             'lives_in': []     # Not found in second text
#         }
#     },
#     {
#         'relation_extraction': {
#             'works_for': [('Alice', 'Google')],
#             'reports_to': [('Bob', 'Alice')],
#             'founded': [],     # Not found in third text
#             'lives_in': []     # Not found in third text
#         }
#     }
# ]
```

**Note**: All requested relation types appear in each result, even if empty. This ensures a consistent structure across all batch results, making the output easier to process programmatically.

## Combining with Other Tasks

Relation extraction can be combined with entity extraction, classification, and structured extraction:

### Relations + Entities

```python
schema = (extractor.create_schema()
    .entities(["person", "organization", "location"])
    .relations(["works_for", "located_in"])
)

text = "Tim Cook works for Apple Inc., which is located in Cupertino, California."
results = extractor.extract(text, schema)
# Output: {
#     'entities': {
#         'person': ['Tim Cook'],
#         'organization': ['Apple Inc.'],
#         'location': ['Cupertino', 'California']
#     },
#     'relation_extraction': {
#         'works_for': [('Tim Cook', 'Apple Inc.')],
#         'located_in': [('Apple Inc.', 'Cupertino')]
#     }
# }
```

### Relations + Classification + Structures

```python
schema = (extractor.create_schema()
    .classification("document_type", ["news", "report", "announcement"])
    .entities(["person", "company"])
    .relations(["works_for", "acquired"])
    .structure("event")
        .field("date", dtype="str")
        .field("description", dtype="str")
)

text = """
BREAKING: Microsoft announced today that it acquired GitHub.
Satya Nadella, CEO of Microsoft, confirmed the deal.
The acquisition was finalized on October 26, 2018.
"""

results = extractor.extract(text, schema)
```

## Real-World Examples

### Organizational Relationships

```python
org_schema = extractor.create_schema().relations({
    "reports_to": "Direct reporting relationship in organizational hierarchy",
    "manages": "Management relationship where person manages team/department",
    "works_for": "Employment relationship",
    "founded": "Founding relationship",
    "acquired": "Company acquisition relationship"
})

text = """
Sundar Pichai is the CEO of Google. He reports to the board of directors.
Google acquired YouTube in 2006. Many engineers work for Google.
"""

results = extractor.extract(text, org_schema)
# Output: {
#     'relation_extraction': {
#         'reports_to': [('Sundar Pichai', 'board of directors')],
#         'works_for': [('engineers', 'Google')],
#         'acquired': [('Google', 'YouTube')],
#         'manages': [],  # Not found in text
#         'founded': []   # Not found in text
#     }
# }
```

### Medical Relationships

```python
medical_schema = extractor.create_schema().relations({
    "treats": "Medical treatment relationship between doctor and patient",
    "prescribed_for": "Prescription relationship between medication and condition",
    "causes": "Causal relationship between condition and symptom",
    "located_in": "Anatomical location relationship"
})

text = """
Dr. Smith treats patients with diabetes. Metformin is prescribed for Type 2 Diabetes.
High blood sugar causes frequent urination. The pancreas is located in the abdomen.
"""

results = extractor.extract(text, medical_schema)
```

### Financial Relationships

```python
finance_schema = extractor.create_schema().relations({
    "invested_in": "Investment relationship between investor and company",
    "acquired": "Company acquisition relationship",
    "merged_with": "Merger relationship between companies",
    "owns": "Ownership relationship"
})

text = """
SoftBank invested in Uber in 2018. Microsoft acquired LinkedIn in 2016.
Disney merged with 21st Century Fox. Berkshire Hathaway owns Geico.
"""

results = extractor.extract(text, finance_schema)
```

### Geographic Relationships

```python
geo_schema = extractor.create_schema().relations({
    "located_in": "Geographic containment (city in country, etc.)",
    "borders": "Geographic adjacency relationship",
    "capital_of": "Capital city relationship",
    "flows_through": "River or waterway relationship"
})

text = """
Paris is the capital of France. France borders Germany and Spain.
The Seine flows through Paris. Paris is located in France.
"""

results = extractor.extract(text, geo_schema)
```

### Family Relationships

```python
family_schema = extractor.create_schema().relations({
    "married_to": "Marriage relationship",
    "parent_of": "Parent-child relationship",
    "sibling_of": "Sibling relationship",
    "related_to": "General family relationship"
})

text = """
John is married to Mary. They are parents of two children: Alice and Bob.
Alice and Bob are siblings. Mary is related to her sister Sarah.
"""

results = extractor.extract(text, family_schema)
```

### Academic Relationships

```python
academic_schema = extractor.create_schema().relations({
    "authored": "Publication relationship between author and paper",
    "cited": "Citation relationship between papers",
    "supervised": "Academic supervision relationship",
    "affiliated_with": "Institutional affiliation relationship"
})

text = """
Dr. Johnson authored the paper on machine learning. The paper cited
previous work by Dr. Smith. Dr. Johnson supervises graduate students
at MIT, where she is affiliated with the Computer Science department.
"""

results = extractor.extract(text, academic_schema)
```

## Best Practices

### 1. Use Clear, Specific Relation Names

```python
# Good - Clear and specific
schema.relations(["works_for", "reports_to", "manages"])

# Less ideal - Too generic
schema.relations(["related", "connected", "linked"])
```

### 2. Provide Descriptions for Ambiguous Relations

```python
# Good - Clear descriptions
schema.relations({
    "works_for": "Employment relationship where person works at organization",
    "consulted_for": "Consulting relationship where person provides services to organization"
})

# Less ideal - No context
schema.relations(["works_for", "consulted_for"])
```

### 3. Set Appropriate Thresholds

```python
# High precision for critical relations
schema.relations({
    "acquired": {
        "description": "Company acquisition",
        "threshold": 0.9  # Very high precision
    },
    "partnered_with": {
        "description": "Partnership relationship",
        "threshold": 0.6  # Moderate threshold
    }
})
```

### 4. Combine with Entity Extraction

```python
# Extract both entities and relations for better context
schema = (extractor.create_schema()
    .entities(["person", "organization"])
    .relations(["works_for", "founded"])
)
```

### 5. Use Batch Processing for Multiple Texts

```python
# Efficient batch processing
results = extractor.batch_extract_relations(
    texts,
    relation_types,
    batch_size=8  # Adjust based on your hardware
)
```

### 6. Handle Multiple Instances

```python
# GLiNER2 automatically extracts all instances
text = "John works for Apple. Mary works for Google. Bob works for Microsoft."
results = extractor.extract_relations(text, ["works_for"])
# Returns all three work relationships
```

### 7. Handle Empty Relations

All requested relation types are always included in the output, even if empty:

```python
results = extractor.extract_relations(
    "John works for Microsoft.",
    ["works_for", "founded", "acquired"]
)
# Output: {
#     'relation_extraction': {
#         'works_for': [('John', 'Microsoft')],
#         'founded': [],   # Empty - not found in text
#         'acquired': []   # Empty - not found in text
#     }
# }

# This makes it easy to check for relations programmatically:
for rel_type, rels in results['relation_extraction'].items():
    if rels:  # Non-empty
        print(f"Found {len(rels)} {rel_type} relations")
    else:  # Empty
        print(f"No {rel_type} relations found")
```

### 8. Validate Relation Direction

Relations are directional tuples `(source, target)`:

- `works_for`: (person, organization)
- `located_in`: (entity, location)
- `reports_to`: (subordinate, manager)
- `manages`: (manager, team)

Make sure your relation names match the expected direction.
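
To keep the direction explicit when storing or comparing results, the per-type pairs can be flattened into `(source, relation, target)` triples. A minimal sketch over an illustrative result dict:

```python
# Illustrative extraction result (not real model output)
results = {
    "relation_extraction": {
        "reports_to": [("Bob", "Alice")],
        "manages": [("Alice", "Engineering team")],
        "founded": [],  # Requested but not found
    }
}

# Flatten into (source, relation, target) triples, preserving direction
triples = [
    (source, rel_type, target)
    for rel_type, pairs in results["relation_extraction"].items()
    for source, target in pairs
]
print(triples)
# [('Bob', 'reports_to', 'Alice'), ('Alice', 'manages', 'Engineering team')]
```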

## Common Use Cases

### Knowledge Graph Construction

```python
# Extract entities and relations for knowledge graph
schema = (extractor.create_schema()
    .entities(["person", "organization", "location", "product"])
    .relations([
        "works_for", "founded", "located_in", "created",
        "acquired", "partnered_with"
    ])
)

# Process documents to build knowledge graph
documents = [...]  # Your documents
all_relations = []
all_entities = []

for doc in documents:
    results = extractor.extract(doc, schema)
    all_relations.append(results.get("relation_extraction", {}))
    all_entities.append(results.get("entities", {}))
```
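
The collected relations can then be merged into a simple adjacency structure. A minimal sketch using plain dicts (a dedicated graph library would work the same way); the `all_relations` list here is illustrative sample data:

```python
from collections import defaultdict

# Illustrative per-document relation results
all_relations = [
    {"works_for": [("Tim Cook", "Apple Inc.")]},
    {"located_in": [("Apple Inc.", "Cupertino")]},
]

# entity -> list of (relation, target) edges
graph = defaultdict(list)
for relations in all_relations:
    for rel_type, pairs in relations.items():
        for source, target in pairs:
            graph[source].append((rel_type, target))

print(dict(graph))
# {'Tim Cook': [('works_for', 'Apple Inc.')], 'Apple Inc.': [('located_in', 'Cupertino')]}
```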
|
||||
|
||||
### Relationship Analysis
|
||||
|
||||
```python
|
||||
# Analyze organizational structures
|
||||
org_texts = [...] # Organizational documents
|
||||
results = extractor.batch_extract_relations(
|
||||
org_texts,
|
||||
["reports_to", "manages", "works_for", "collaborates_with"],
|
||||
batch_size=8
|
||||
)
|
||||
|
||||
# Analyze relationship patterns
|
||||
for result in results:
|
||||
relations = result.get("relation_extraction", {})
|
||||
# Process relations for analysis
|
||||
```
|
||||
|
||||
### Document Understanding
|
||||
|
||||
```python
|
||||
# Comprehensive document understanding
|
||||
schema = (extractor.create_schema()
|
||||
.classification("document_type", ["contract", "report", "email"])
|
||||
.entities(["person", "organization", "date", "amount"])
|
||||
.relations(["signed_by", "involves", "dated", "worth"])
|
||||
.structure("contract_term")
|
||||
.field("term", dtype="str")
|
||||
.field("value", dtype="str")
|
||||
)
|
||||
|
||||
# Extract all information types in one pass
|
||||
results = extractor.extract(document_text, schema)
|
||||
```
|
||||
|
||||
514
packages/GLiNER2/tutorial/7-api.md
Normal file
514
packages/GLiNER2/tutorial/7-api.md
Normal file
@ -0,0 +1,514 @@
|
||||
# GLiNER2 API Extractor
|
||||
|
||||
Use GLiNER2 through a cloud API without loading models locally. Perfect for production deployments, low-memory environments, or when you need instant access without GPU setup.
|
||||
|
||||
## Table of Contents
|
||||
- [Getting Started](#getting-started)
|
||||
- [Basic Usage](#basic-usage)
|
||||
- [Entity Extraction](#entity-extraction)
|
||||
- [Text Classification](#text-classification)
|
||||
- [Structured Extraction](#structured-extraction)
|
||||
- [Relation Extraction](#relation-extraction)
|
||||
- [Combined Schemas](#combined-schemas)
|
||||
- [Batch Processing](#batch-processing)
|
||||
- [Confidence Scores](#confidence-scores)
|
||||
- [Error Handling](#error-handling)
|
||||
- [API vs Local](#api-vs-local)
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Get Your API Key
|
||||
|
||||
1. Visit [gliner.pioneer.ai](https://gliner.pioneer.ai)
|
||||
2. Sign up or log in to your account
|
||||
3. Navigate to API Keys section
|
||||
4. Generate a new API key
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install gliner2
|
||||
```
|
||||
|
||||
### Set Your API Key
|
||||
|
||||
**Option 1: Environment Variable (Recommended)**
|
||||
```bash
|
||||
export PIONEER_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
**Option 2: Pass Directly**
|
||||
```python
|
||||
extractor = GLiNER2.from_api(api_key="your-api-key-here")
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from gliner2 import GLiNER2
|
||||
|
||||
# Load from API (uses PIONEER_API_KEY environment variable)
|
||||
extractor = GLiNER2.from_api()
|
||||
|
||||
# Use exactly like the local model!
|
||||
results = extractor.extract_entities(
|
||||
"Apple CEO Tim Cook announced the iPhone 15 in Cupertino.",
|
||||
["company", "person", "product", "location"]
|
||||
)
|
||||
print(results)
|
||||
# Output: {
|
||||
# 'entities': {
|
||||
# 'company': ['Apple'],
|
||||
# 'person': ['Tim Cook'],
|
||||
# 'product': ['iPhone 15'],
|
||||
# 'location': ['Cupertino']
|
||||
# }
|
||||
# }
|
||||
```
|
||||
|
||||
## Entity Extraction
|
||||
|
||||
### Simple Extraction
|
||||
|
||||
```python
|
||||
extractor = GLiNER2.from_api()
|
||||
|
||||
text = "Elon Musk founded SpaceX in 2002 and Tesla in 2003."
|
||||
results = extractor.extract_entities(
|
||||
text,
|
||||
["person", "company", "date"]
|
||||
)
|
||||
# Output: {
|
||||
# 'entities': {
|
||||
# 'person': ['Elon Musk'],
|
||||
# 'company': ['SpaceX', 'Tesla'],
|
||||
# 'date': ['2002', '2003']
|
||||
# }
|
||||
# }
|
||||
```
|
||||
|
||||
### With Confidence Scores and Character Positions
|
||||
|
||||
You can include confidence scores and character-level start/end positions using `include_confidence` and `include_spans`:
|
||||
|
||||
```python
|
||||
# With confidence only
|
||||
results = extractor.extract_entities(
|
||||
"Microsoft acquired LinkedIn for $26.2 billion.",
|
||||
["company", "price"],
|
||||
include_confidence=True
|
||||
)
|
||||
# Output: {
|
||||
# 'entities': {
|
||||
# 'company': [
|
||||
# {'text': 'Microsoft', 'confidence': 0.98},
|
||||
# {'text': 'LinkedIn', 'confidence': 0.97}
|
||||
# ],
|
||||
# 'price': [
|
||||
# {'text': '$26.2 billion', 'confidence': 0.95}
|
||||
# ]
|
||||
# }
|
||||
# }
|
||||
|
||||
# With character positions (spans) only
|
||||
results = extractor.extract_entities(
|
||||
"Microsoft acquired LinkedIn.",
|
||||
["company"],
|
||||
include_spans=True
|
||||
)
|
||||
# Output: {
|
||||
# 'entities': {
|
||||
# 'company': [
|
||||
# {'text': 'Microsoft', 'start': 0, 'end': 9},
|
||||
# {'text': 'LinkedIn', 'start': 18, 'end': 26}
|
||||
# ]
|
||||
# }
|
||||
# }
|
||||
|
||||
# With both confidence and spans
|
||||
results = extractor.extract_entities(
|
||||
"Microsoft acquired LinkedIn for $26.2 billion.",
|
||||
["company", "price"],
|
||||
include_confidence=True,
|
||||
include_spans=True
|
||||
)
|
||||
# Output: {
|
||||
# 'entities': {
|
||||
# 'company': [
|
||||
# {'text': 'Microsoft', 'confidence': 0.98, 'start': 0, 'end': 9},
|
||||
# {'text': 'LinkedIn', 'confidence': 0.97, 'start': 18, 'end': 26}
|
||||
# ],
|
||||
# 'price': [
|
||||
# {'text': '$26.2 billion', 'confidence': 0.95, 'start': 32, 'end': 45}
|
||||
# ]
|
||||
# }
|
||||
# }
|
||||
```
|
||||
|
||||
### Custom Threshold
|
||||
|
||||
```python
|
||||
# Only return high-confidence extractions
|
||||
results = extractor.extract_entities(
|
||||
text,
|
||||
["person", "company"],
|
||||
threshold=0.8 # Minimum 80% confidence
|
||||
)
|
||||
```
|
||||
|
||||
## Text Classification
|
||||
|
||||
### Single-Label Classification
|
||||
|
||||
```python
|
||||
extractor = GLiNER2.from_api()
|
||||
|
||||
text = "I absolutely love this product! It exceeded all my expectations."
|
||||
results = extractor.classify_text(
|
||||
text,
|
||||
{"sentiment": ["positive", "negative", "neutral"]}
|
||||
)
|
||||
# Output: {'sentiment': {'category': 'positive'}}
|
||||
```
|
||||
|
||||
### Multi-Task Classification
|
||||
|
||||
```python
|
||||
text = "Breaking: Major earthquake hits coastal city. Rescue teams deployed."
|
||||
results = extractor.classify_text(
|
||||
text,
|
||||
{
|
||||
"category": ["politics", "sports", "technology", "disaster", "business"],
|
||||
"urgency": ["low", "medium", "high"]
|
||||
}
|
||||
)
|
||||
# Output: {'category': 'disaster', 'urgency': 'high'}
|
||||
```
|
||||
|
||||
## Structured Extraction
|
||||
|
||||
### Contact Information
|
||||
|
||||
```python
|
||||
extractor = GLiNER2.from_api()
|
||||
|
||||
text = """
|
||||
Contact John Smith at john.smith@email.com or call +1-555-123-4567.
|
||||
He works as a Senior Engineer at TechCorp Inc.
|
||||
"""
|
||||
|
||||
results = extractor.extract_json(
|
||||
text,
|
||||
{
|
||||
"contact": [
|
||||
"name::str::Full name of the person",
|
||||
"email::str::Email address",
|
||||
"phone::str::Phone number",
|
||||
"job_title::str::Professional title",
|
||||
"company::str::Company name"
|
||||
]
|
||||
}
|
||||
)
|
||||
# Output: {
|
||||
# 'contact': [{
|
||||
# 'name': 'John Smith',
|
||||
# 'email': 'john.smith@email.com',
|
||||
# 'phone': '+1-555-123-4567',
|
||||
# 'job_title': 'Senior Engineer',
|
||||
# 'company': 'TechCorp Inc.'
|
||||
# }]
|
||||
# }
|
||||
```
|
||||
|
||||
### Product Information
|
||||
|
||||
```python
|
||||
text = "iPhone 15 Pro Max - $1199, 256GB storage, Natural Titanium color"
|
||||
|
||||
results = extractor.extract_json(
|
||||
text,
|
||||
{
|
||||
"product": [
|
||||
"name::str",
|
||||
"price::str",
|
||||
"storage::str",
|
||||
"color::str"
|
||||
]
|
||||
}
|
||||
)
|
||||
# Output: {
|
||||
# 'product': [{
|
||||
# 'name': 'iPhone 15 Pro Max',
|
||||
# 'price': '$1199',
|
||||
# 'storage': '256GB',
|
||||
# 'color': 'Natural Titanium'
|
||||
# }]
|
||||
# }
|
||||
```
|
||||
|
||||
## Relation Extraction
|
||||
|
||||
Extract relationships between entities as directional tuples (source, target).
|
||||
|
||||
### Basic Relation Extraction
|
||||
|
||||
```python
|
||||
extractor = GLiNER2.from_api()
|
||||
|
||||
text = "John works for Apple Inc. and lives in San Francisco. Apple Inc. is located in Cupertino."
|
||||
results = extractor.extract_relations(
|
||||
text,
|
||||
["works_for", "lives_in", "located_in"]
|
||||
)
|
||||
# Output: {
|
||||
# 'relation_extraction': {
|
||||
# 'works_for': [('John', 'Apple Inc.')],
|
||||
# 'lives_in': [('John', 'San Francisco')],
|
||||
# 'located_in': [('Apple Inc.', 'Cupertino')]
|
||||
# }
|
||||
# }
|
||||
```
|
||||
|
||||
### With Descriptions
|
||||
|
||||
```python
|
||||
text = "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne, California."
|
||||
|
||||
schema = extractor.create_schema().relations({
|
||||
"founded": "Founding relationship where person created organization",
|
||||
"located_in": "Geographic relationship where entity is in a location"
|
||||
})
|
||||
|
||||
results = extractor.extract(text, schema)
|
||||
# Output: {
|
||||
# 'relation_extraction': {
|
||||
# 'founded': [('Elon Musk', 'SpaceX')],
|
||||
# 'located_in': [('SpaceX', 'Hawthorne, California')]
|
||||
# }
|
||||
# }
|
||||
```
|
||||
|
||||
### Batch Relation Extraction
|
||||
|
||||
```python
|
||||
texts = [
|
||||
"John works for Microsoft and lives in Seattle.",
|
||||
"Sarah founded TechStartup in 2020.",
|
||||
"Bob reports to Alice at Google."
|
||||
]
|
||||
|
||||
results = extractor.batch_extract_relations(
|
||||
texts,
|
||||
["works_for", "founded", "reports_to", "lives_in"]
|
||||
)
|
||||
# Returns list of relation extraction results for each text
|
||||
```
|
||||
|
||||
## Combined Schemas
|
||||
|
||||
Combine entities, classification, relations, and structured extraction in a single call.
|
||||
|
||||
```python
|
||||
extractor = GLiNER2.from_api()
|
||||
|
||||
text = """
|
||||
Tech Review: The new MacBook Pro M3 is absolutely fantastic! Apple has outdone themselves.
|
||||
I tested it in San Francisco last week. Tim Cook works for Apple, which is located in Cupertino.
|
||||
Highly recommended for developers. Rating: 5 out of 5 stars.
|
||||
"""
|
||||
|
||||
schema = (extractor.create_schema()
|
||||
.entities(["company", "product", "location", "person"])
|
||||
.classification("sentiment", ["positive", "negative", "neutral"])
|
||||
.relations(["works_for", "located_in"])
|
||||
.structure("review")
|
||||
.field("product_name", dtype="str")
|
||||
.field("rating", dtype="str")
|
||||
.field("recommendation", dtype="str")
|
||||
)
|
||||
|
||||
results = extractor.extract(text, schema)
|
||||
# Output: {
|
||||
# 'entities': {
|
||||
# 'company': ['Apple'],
|
||||
# 'product': ['MacBook Pro M3'],
|
||||
# 'location': ['San Francisco', 'Cupertino'],
|
||||
# 'person': ['Tim Cook']
|
||||
# },
|
||||
# 'sentiment': 'positive',
|
||||
# 'relation_extraction': {
|
||||
# 'works_for': [('Tim Cook', 'Apple')],
|
||||
# 'located_in': [('Apple', 'Cupertino')]
|
||||
# },
|
||||
# 'review': [{
|
||||
# 'product_name': 'MacBook Pro M3',
|
||||
# 'rating': '5 out of 5 stars',
|
||||
# 'recommendation': 'Highly recommended for developers'
|
||||
# }]
|
||||
# }
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
Process multiple texts efficiently in a single API call.
|
||||
|
||||
```python
|
||||
extractor = GLiNER2.from_api()
|
||||
|
||||
texts = [
|
||||
"Google's Sundar Pichai unveiled Gemini AI in Mountain View.",
|
||||
"Microsoft CEO Satya Nadella announced Copilot at Build 2023.",
|
||||
"Amazon's Andy Jassy revealed new AWS services in Seattle."
|
||||
]
|
||||
|
||||
results = extractor.batch_extract_entities(
|
||||
texts,
|
||||
["company", "person", "product", "location"]
|
||||
)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"Text {i+1}: {result}")
|
||||
```
|
||||
|
||||
## Confidence Scores and Character Positions
|
||||
|
||||
### Entity Extraction with Confidence
|
||||
|
||||
```python
|
||||
# Include confidence scores
|
||||
results = extractor.extract_entities(
|
||||
"Apple released the iPhone 15 in September 2023.",
|
||||
["company", "product", "date"],
|
||||
include_confidence=True
|
||||
)
|
||||
# Each entity includes: {'text': '...', 'confidence': 0.95}
|
||||
```
|
||||
|
||||
### Entity Extraction with Character Positions
|
||||
|
||||
```python
|
||||
# Include character-level start/end positions
|
||||
results = extractor.extract_entities(
|
||||
"Apple released the iPhone 15.",
|
||||
["company", "product"],
|
||||
include_spans=True
|
||||
)
|
||||
# Each entity includes: {'text': '...', 'start': 0, 'end': 5}
|
||||
```
|
||||
|
||||
### Both Confidence and Positions
|
||||
|
||||
```python
|
||||
# Include both confidence and character positions
|
||||
results = extractor.extract_entities(
|
||||
"Apple released the iPhone 15 in September 2023.",
|
||||
["company", "product", "date"],
|
||||
include_confidence=True,
|
||||
include_spans=True
|
||||
)
|
||||
# Each entity includes: {'text': '...', 'confidence': 0.95, 'start': 0, 'end': 5}
|
||||
```
|
||||
|
||||
### Raw Results (Advanced)
|
||||
|
||||
For full control over the extraction data:
|
||||
|
||||
```python
|
||||
results = extractor.extract_entities(
|
||||
"Apple CEO Tim Cook announced new products.",
|
||||
["company", "person"],
|
||||
format_results=False, # Get raw extraction data
|
||||
include_confidence=True,
|
||||
include_spans=True
|
||||
)
|
||||
# Returns tuples: (text, confidence, start_char, end_char)
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
```python
|
||||
from gliner2 import GLiNER2, GLiNER2APIError, AuthenticationError, ValidationError
|
||||
|
||||
try:
|
||||
extractor = GLiNER2.from_api()
|
||||
results = extractor.extract_entities(text, entity_types)
|
||||
|
||||
except AuthenticationError:
|
||||
print("Invalid API key. Check your PIONEER_API_KEY.")
|
||||
|
||||
except ValidationError as e:
|
||||
print(f"Invalid request: {e}")
|
||||
|
||||
except GLiNER2APIError as e:
|
||||
print(f"API error: {e}")
|
||||
```
|
||||
|
||||
### Connection Settings
|
||||
|
||||
```python
|
||||
extractor = GLiNER2.from_api(
|
||||
api_key="your-key",
|
||||
timeout=60.0, # Request timeout (seconds)
|
||||
max_retries=5 # Retry failed requests
|
||||
)
|
||||
```
|
||||
|
||||
## API vs Local
|
||||
|
||||
| Feature | API (`from_api()`) | Local (`from_pretrained()`) |
|
||||
|---------|-------------------|----------------------------|
|
||||
| Setup | Just API key | GPU/CPU + model download |
|
||||
| Memory | ~0 MB | 2-8 GB+ |
|
||||
| Latency | Network dependent | Faster for single texts |
|
||||
| Batch | Optimized | Optimized |
|
||||
| Cost | Per request | Free after setup |
|
||||
| Offline | ❌ | ✅ |
|
||||
| RegexValidator | ❌ | ✅ |
|
||||
|
||||
### When to Use API

- Production deployments without GPU
- Serverless functions (AWS Lambda, etc.)
- Quick prototyping
- Low-memory environments
- Mobile/edge applications

### When to Use Local

- High-volume processing
- Offline requirements
- Sensitive data (no network transfer)
- Need for RegexValidator
- Cost optimization at scale

## Seamless Switching

The API mirrors the local interface exactly, making switching trivial:

```python
# Development: Use API for quick iteration
extractor = GLiNER2.from_api()

# Production: Switch to local if needed
# extractor = GLiNER2.from_pretrained("your-model")

# Same code works with both!
results = extractor.extract_entities(text, entity_types)
```

## Limitations

The API currently does not support:

1. **RegexValidator** - Use a local model for regex-based filtering
2. **Multi-schema batch** - Different schemas per text in a batch (works, but slower)
3. **Custom models** - The API uses the default GLiNER2 model

## Best Practices

1. **Store API keys securely** - Use environment variables, not hardcoded strings
2. **Handle errors gracefully** - Network issues can occur
3. **Use batch processing** - More efficient than individual calls
4. **Set appropriate timeouts** - Increase them for large texts
5. **Cache results** - Avoid redundant API calls for the same content

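The caching recommendation above can be sketched as a thin wrapper; `make_cached_extractor` and the `fake_extract` stub below are illustrative helpers, not part of the GLiNER2 API:

```python
import hashlib
import json

def make_cached_extractor(extract_fn):
    """Wrap an extraction callable with an in-memory cache keyed by the
    text and the (order-insensitive) entity types, so repeated calls for
    the same content never hit the API twice."""
    cache = {}

    def cached(text, entity_types):
        key = hashlib.sha256(
            json.dumps([text, sorted(entity_types)]).encode("utf-8")
        ).hexdigest()
        if key not in cache:
            cache[key] = extract_fn(text, entity_types)
        return cache[key]

    return cached

# Usage with a stub extractor that records how often it is actually called:
calls = []
def fake_extract(text, entity_types):
    calls.append(text)
    return {"person": ["Tim Cook"]}

extract = make_cached_extractor(fake_extract)
extract("Apple CEO Tim Cook announced new products.", ["person", "company"])
extract("Apple CEO Tim Cook announced new products.", ["company", "person"])
# The second call hits the cache: the underlying extractor ran only once.
```

In production you would replace the dict with a bounded or persistent cache, but the keying idea is the same.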
630
packages/GLiNER2/tutorial/8-train_data.md
Normal file
@ -0,0 +1,630 @@
# GLiNER2 Training Dataset Formats

GLiNER2 uses a JSONL format in which each line contains an `input` and an `output` field (or, alternatively, `text` and `schema`). The `input`/`text` is the text to process; the `output`/`schema` is the schema with labels/annotations.

## Quick Format Reference

### General Structure

**Primary Format**:
```jsonl
{"input": "text to process", "output": {"schema_definition": "with_annotations"}}
```

**Alternative Format** (also supported):
```jsonl
{"text": "text to process", "schema": {"schema_definition": "with_annotations"}}
```

Both formats are equivalent - use whichever is more convenient for your workflow.

### Valid Output Schema Keys

| Key | Type | Required | Description |
|-----|------|----------|-------------|
| `entities` | `dict[str, list[str]]` | No | Entity type → list of entity mentions |
| `entity_descriptions` | `dict[str, str]` | No | Entity type → description |
| `classifications` | `list[dict]` | No | List of classification tasks |
| `json_structures` | `list[dict]` | No | List of structured data extractions |
| `json_descriptions` | `dict[str, dict[str, str]]` | No | Parent → field → description |
| `relations` | `list[dict]` | No | List of relation extractions |

### Classification Task Fields

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `task` | `str` | Yes | Task identifier |
| `labels` | `list[str]` | Yes | Available label options |
| `true_label` | `list[str]` or `str` | Yes | Correct label(s) |
| `multi_label` | `bool` | No | Enable multi-label classification |
| `prompt` | `str` | No | Custom prompt for the task |
| `examples` | `list[list[str]]` or `list[tuple[str, str]]` | No | Few-shot examples as `[[input, output], ...]` pairs; internally converted to a list of lists |
| `label_descriptions` | `dict[str, str]` | No | Label → description mapping |

### Entity Fields Format

Entities use a simple dictionary where keys are entity types and values are lists of mentions:

| Component | Type | Required | Description |
|-----------|------|----------|-------------|
| Entity type (key) | `str` | Yes | Name of the entity type (e.g., "person", "location") |
| Entity mentions (value) | `list[str]` | Yes | List of entity text spans found in the input |

**Format**: `{"entity_type": ["mention1", "mention2", ...]}`

### JSON Structure Fields Format

Each structure is a dictionary with a parent name as key and field definitions as value:

| Component | Type | Required | Description |
|-----------|------|----------|-------------|
| Parent name (key) | `str` | Yes | Name of the structure (e.g., "product", "contact") |
| Fields (value) | `dict` | Yes | Field name → field value mappings |
| Field value | `str` or `list[str]` or `dict` | Yes | String, list of strings, or choice dict |
| Choice dict | `dict` with `value` and `choices` | No | For classification-style fields |

**Format**: `[{"parent": {"field1": "value", "field2": ["list", "values"]}}]`

**Multiple Instances**: When the same parent appears multiple times, each instance is a separate dict in the list:
```jsonl
[{"hotel": {"name": "Hotel A", ...}}, {"hotel": {"name": "Hotel B", ...}}]
```

### Relation Fields Format

Relations use flexible field structures - you can use ANY field names (not just "head" and "tail"):

| Component | Type | Required | Description |
|-----------|------|----------|-------------|
| Relation name (key) | `str` | Yes | Name of the relation type (e.g., "works_for") |
| Fields (value) | `dict` | Yes | Field name → field value mappings |
| Field value | `str` or `list[str]` | Yes | String or list of strings |

**Standard Format**: `[{"relation_name": {"head": "entity1", "tail": "entity2"}}]`

**⚠️ Critical Constraint**: For a given relation type, the **first occurrence** defines the field structure:

- The first instance of "works_for" determines what fields ALL "works_for" instances must have
- All subsequent instances of the same relation type must use the same field names
- Different relation types can have different field structures
- **This consistency is enforced during validation** - inconsistent field structures will raise a `ValidationError`

**Example**: If the first "works_for" has `{"head": "...", "tail": "..."}`, all other "works_for" instances must also have "head" and "tail" fields.

**Validation**: The `TrainingDataset.validate_relation_consistency()` method checks that all relation types have consistent field structures across the entire dataset.

---

## Alternative Input Formats

The training data loader supports multiple input formats:

1. **JSONL files**: `{"input": "...", "output": {...}}` or `{"text": "...", "schema": {...}}`
2. **Python API**: Use the `InputExample` and `TrainingDataset` classes from `gliner2.training.data`
3. **Dict lists**: List of dictionaries in the same format as JSONL

All formats are automatically detected and converted to the internal format. See `gliner2.training.data.DataLoader_Factory` for details.
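The automatic detection can be sketched for the two JSONL key layouts documented above (`normalize_record` is an illustrative helper, not a GLiNER2 function):

```python
import json

def normalize_record(record):
    """Accept either {"input", "output"} or {"text", "schema"} keys and
    return the primary {"input", "output"} layout."""
    if "input" in record and "output" in record:
        return record
    if "text" in record and "schema" in record:
        return {"input": record["text"], "output": record["schema"]}
    raise ValueError(f"Unrecognized record keys: {sorted(record)}")

line = '{"text": "John works at OpenAI.", "schema": {"entities": {"person": ["John"]}}}'
normalized = normalize_record(json.loads(line))
print(normalized["input"])  # John works at OpenAI.
```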

---

## 1. Classification Tasks

### Basic Single-Label Classification

```jsonl
{"input": "This movie is absolutely fantastic! I loved every minute of it.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}}
{"input": "The service at this restaurant was terrible and the food was cold.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["negative"]}]}}
{"input": "The weather today is okay, nothing special.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["neutral"]}]}}
```

### Multi-label Classification

```jsonl
{"input": "This smartphone has an amazing camera but the battery life is poor.", "output": {"classifications": [{"task": "product_aspects", "labels": ["camera", "battery", "screen", "performance", "design"], "true_label": ["camera", "battery"], "multi_label": true}]}}
{"input": "Great performance and beautiful design!", "output": {"classifications": [{"task": "product_aspects", "labels": ["camera", "battery", "screen", "performance", "design"], "true_label": ["performance", "design"], "multi_label": true}]}}
```

### Classification with Label Descriptions

```jsonl
{"input": "Breaking: New AI model achieves human-level performance on reasoning tasks.", "output": {"classifications": [{"task": "news_category", "labels": ["technology", "politics", "sports", "entertainment"], "true_label": ["technology"], "label_descriptions": {"technology": "Articles about computers, AI, software, and tech innovations", "politics": "Government, elections, and political news", "sports": "Athletic events, teams, and competitions", "entertainment": "Movies, music, celebrities, and entertainment news"}}]}}
```

### Classification with Custom Prompts

```jsonl
{"input": "The patient shows signs of improvement after treatment.", "output": {"classifications": [{"task": "medical_assessment", "labels": ["improving", "stable", "declining", "critical"], "true_label": ["improving"], "prompt": "Assess the patient's medical condition based on the clinical notes."}]}}
```

### Classification with Few-Shot Examples

Few-shot examples are provided as a list of `[input, output]` pairs. Each example is a list/tuple with exactly 2 elements:

```jsonl
{"input": "This service exceeded all my expectations!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"], "examples": [["Great product, highly recommend!", "positive"], ["Terrible experience, very disappointed.", "negative"], ["It's okay, nothing special.", "neutral"]]}]}}
```

**Format**: `"examples": [[input_text, output_label], [input_text, output_label], ...]`

Each example pair must have exactly 2 elements: the input text and the corresponding label.

### Classification with Both Examples and Descriptions

```jsonl
{"input": "The algorithm demonstrates linear time complexity.", "output": {"classifications": [{"task": "complexity", "labels": ["constant", "linear", "quadratic", "exponential"], "true_label": ["linear"], "examples": [["O(1) lookup time", "constant"], ["O(n) iteration", "linear"]], "label_descriptions": {"constant": "O(1) - fixed time regardless of input size", "linear": "O(n) - time scales linearly with input", "quadratic": "O(n²) - nested iterations", "exponential": "O(2ⁿ) - recursive branching"}}]}}
```

### Multiple Classification Tasks

```jsonl
{"input": "Exciting new smartphone with innovative features!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}, {"task": "category", "labels": ["technology", "sports", "politics", "entertainment"], "true_label": ["technology"]}]}}
```

### true_label: String vs List Format

Both formats are supported - use list for consistency or string for brevity:

```jsonl
{"input": "Sample text A", "output": {"classifications": [{"task": "label", "labels": ["a", "b"], "true_label": ["a"]}]}}
{"input": "Sample text B", "output": {"classifications": [{"task": "label", "labels": ["a", "b"], "true_label": "b"}]}}
{"input": "This is great!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": "positive"}]}}
```

**Note**:
- String format (`"true_label": "positive"`) and list format (`"true_label": ["positive"]`) are both valid for single-label classification
- Internally, string values are automatically converted to lists (`["positive"]`)
- For multi-label classification, always use list format: `"true_label": ["label1", "label2"]`
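The two rules above — a string `true_label` coerced to a list, and few-shot examples as strict 2-element pairs — can be sketched as a small normalizer (an illustrative helper, not the library's loader):

```python
def normalize_classification(task):
    """Normalize one classification dict: coerce a string true_label to a
    list, check labels, and check that each few-shot example is a pair."""
    task = dict(task)  # don't mutate the caller's dict
    if isinstance(task["true_label"], str):
        task["true_label"] = [task["true_label"]]
    for label in task["true_label"]:
        if label not in task["labels"]:
            raise ValueError(f"true_label {label!r} not in labels")
    for example in task.get("examples", []):
        if len(example) != 2:
            raise ValueError(f"example must have exactly 2 elements: {example!r}")
    return task

task = {"task": "sentiment",
        "labels": ["positive", "negative", "neutral"],
        "true_label": "positive",
        "examples": [["Great product!", "positive"]]}
print(normalize_classification(task)["true_label"])  # ['positive']
```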

---

## 2. Named Entity Recognition (NER)

### Basic NER

```jsonl
{"input": "John Smith works at OpenAI in San Francisco and will visit London next month.", "output": {"entities": {"person": ["John Smith"], "organization": ["OpenAI"], "location": ["San Francisco", "London"]}}}
{"input": "Apple Inc. CEO Tim Cook announced the iPhone 15 release date.", "output": {"entities": {"person": ["Tim Cook"], "organization": ["Apple Inc."], "product": ["iPhone 15"]}}}
{"input": "The meeting on January 15, 2024 will be held at Microsoft headquarters.", "output": {"entities": {"date": ["January 15, 2024"], "organization": ["Microsoft"]}}}
```

### NER with Entity Descriptions

```jsonl
{"input": "Dr. Sarah Johnson prescribed Metformin 500mg twice daily for diabetes treatment.", "output": {"entities": {"person": ["Dr. Sarah Johnson"], "medication": ["Metformin"], "dosage": ["500mg"], "condition": ["diabetes"]}, "entity_descriptions": {"person": "Names of people mentioned in the text", "medication": "Names of drugs or pharmaceutical products", "dosage": "Specific amounts or dosages of medications", "condition": "Medical conditions or diseases"}}}
```

### NER with Multiple Instances of Same Entity Type

```jsonl
{"input": "Alice, Bob, and Charlie attended the meeting with David.", "output": {"entities": {"person": ["Alice", "Bob", "Charlie", "David"]}}}
```

### NER with Empty Entity Types

```jsonl
{"input": "The conference will be held next week.", "output": {"entities": {"person": [], "organization": [], "location": []}}}
```

### Partial NER (Some Entity Types Present)

```jsonl
{"input": "Microsoft announced new features.", "output": {"entities": {"organization": ["Microsoft"], "person": []}}}
```
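When preparing NER data in this format, a quick sanity check is that every listed mention occurs verbatim in the input text; a minimal sketch (simple substring matching, not part of GLiNER2's own validation):

```python
def missing_mentions(example):
    """Return (entity_type, mention) pairs whose text does not occur
    verbatim in the input — a common sign of annotation errors."""
    text = example["input"]
    missing = []
    for entity_type, mentions in example.get("output", {}).get("entities", {}).items():
        for mention in mentions:
            if mention not in text:
                missing.append((entity_type, mention))
    return missing

example = {"input": "Microsoft announced new features.",
           "output": {"entities": {"organization": ["Microsoft"], "person": ["Satya"]}}}
print(missing_mentions(example))  # [('person', 'Satya')]
```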

---

## 3. JSON Structure Extraction

### Basic Structure with String Fields

```jsonl
{"input": "Contact John Doe at john.doe@email.com or call (555) 123-4567.", "output": {"json_structures": [{"contact": {"name": "John Doe", "email": "john.doe@email.com", "phone": "(555) 123-4567"}}]}}
```

### Structure with List Fields

```jsonl
{"input": "Product features include: wireless charging, water resistance, and face recognition.", "output": {"json_structures": [{"product": {"features": ["wireless charging", "water resistance", "face recognition"]}}]}}
```

### Structure with Mixed String and List Fields

```jsonl
{"input": "iPhone 15 costs $999 and comes in blue, black, and white colors.", "output": {"json_structures": [{"product": {"name": "iPhone 15", "price": "$999", "colors": ["blue", "black", "white"]}}]}}
```

### Multiple Instances of Same Structure Type

When the **same structure type** (parent name) appears multiple times in the text, each instance is a **separate dictionary** in the `json_structures` list:

```jsonl
{"input": "We have two hotels available: Hotel Paradise with 4 stars, pool, and wifi for $150/night, and Budget Inn with 2 stars and parking for $80/night.", "output": {"json_structures": [{"hotel": {"name": "Hotel Paradise", "stars": "4", "amenities": ["pool", "wifi"], "price": "$150/night"}}, {"hotel": {"name": "Budget Inn", "stars": "2", "amenities": ["parking"], "price": "$80/night"}}]}}
```

**Note**: Both instances use the same parent key "hotel" but are separate objects in the list. This is how you represent multiple occurrences of the same structure type.

Another example with three products:

```jsonl
{"input": "Available products: iPhone 15 for $999, MacBook Pro for $1999, and AirPods for $199.", "output": {"json_structures": [{"product": {"name": "iPhone 15", "price": "$999"}}, {"product": {"name": "MacBook Pro", "price": "$1999"}}, {"product": {"name": "AirPods", "price": "$199"}}]}}
```

### Structure with Classification Fields (Choices)

```jsonl
{"input": "Book a single room at Grand Hotel for 2 nights with breakfast included.", "output": {"json_structures": [{"booking": {"hotel": "Grand Hotel", "room_type": {"value": "single", "choices": ["single", "double", "suite"]}, "nights": "2", "meal_plan": {"value": "breakfast", "choices": ["none", "breakfast", "half-board", "full-board"]}}}]}}
```

### Structure with Multiple Choice Fields

```jsonl
{"input": "Order a large pepperoni pizza for delivery, extra cheese.", "output": {"json_structures": [{"order": {"size": {"value": "large", "choices": ["small", "medium", "large", "xlarge"]}, "type": {"value": "pepperoni", "choices": ["cheese", "pepperoni", "veggie", "supreme"]}, "method": {"value": "delivery", "choices": ["pickup", "delivery", "dine-in"]}, "extras": ["extra cheese"]}}]}}
```

### Structure with Field Descriptions

```jsonl
{"input": "Patient: Mary Wilson, Age: 45, diagnosed with hypertension, prescribed Lisinopril 10mg daily.", "output": {"json_structures": [{"medical_record": {"patient_name": "Mary Wilson", "age": "45", "diagnosis": "hypertension", "medication": "Lisinopril", "dosage": "10mg daily"}}], "json_descriptions": {"medical_record": {"patient_name": "Full name of the patient", "age": "Patient's age in years", "diagnosis": "Medical condition diagnosed", "medication": "Prescribed medication name", "dosage": "Medication dosage and frequency"}}}}
```

### Structure with Null/Empty Field Values

```jsonl
{"input": "Product name is Widget X. Price not available.", "output": {"json_structures": [{"product": {"name": "Widget X", "price": "", "description": ""}}]}}
```

### Structure with Some Fields Missing

```jsonl
{"input": "Contact Sarah at sarah@example.com", "output": {"json_structures": [{"contact": {"name": "Sarah", "email": "sarah@example.com", "phone": ""}}]}}
```

### Multiple Different Structure Types

```jsonl
{"input": "John Doe works at TechCorp. Product ABC costs $50 with free shipping.", "output": {"json_structures": [{"employee": {"name": "John Doe", "company": "TechCorp"}}, {"product": {"name": "ABC", "price": "$50", "shipping": "free"}}]}}
```

### Structure with Only List Fields

```jsonl
{"input": "Available colors: red, blue, green. Sizes: S, M, L, XL.", "output": {"json_structures": [{"options": {"colors": ["red", "blue", "green"], "sizes": ["S", "M", "L", "XL"]}}]}}
```


---

## 4. Relation Extraction

Relations use flexible field structures. While "head" and "tail" are common, you can use ANY field names.

**⚠️ Important**: The first occurrence of each relation type defines the field structure for ALL instances of that type.

### Basic Relation (Head and Tail)

```jsonl
{"input": "Alice manages the engineering team.", "output": {"relations": [{"manages": {"head": "Alice", "tail": "engineering team"}}]}}
{"input": "John works for Microsoft.", "output": {"relations": [{"works_for": {"head": "John", "tail": "Microsoft"}}]}}
```

### Multiple Instances - Same Field Structure

All instances of the same relation type MUST have the same fields (determined by the first occurrence):

```jsonl
{"input": "Alice works for Google. Bob works for Microsoft. Charlie works for Amazon.", "output": {"relations": [{"works_for": {"head": "Alice", "tail": "Google"}}, {"works_for": {"head": "Bob", "tail": "Microsoft"}}, {"works_for": {"head": "Charlie", "tail": "Amazon"}}]}}
```

**Note**: All three "works_for" instances use the same fields (head, tail) as defined by the first occurrence.

### Multiple Different Relation Types

Different relation types can have different field structures:

```jsonl
{"input": "John works for Apple Inc. and lives in San Francisco. Apple Inc. is located in Cupertino.", "output": {"relations": [{"works_for": {"head": "John", "tail": "Apple Inc."}}, {"lives_in": {"head": "John", "tail": "San Francisco"}}, {"located_in": {"head": "Apple Inc.", "tail": "Cupertino"}}]}}
```

**Note**: Each relation type ("works_for", "lives_in", "located_in") can independently define its own field structure.

### Custom Field Names (Beyond Head/Tail)

You can use custom field names - the first occurrence defines what fields to use:

```jsonl
{"input": "Alice sent $100 to Bob. Charlie sent $50 to David.", "output": {"relations": [{"transaction": {"sender": "Alice", "recipient": "Bob", "amount": "$100"}}, {"transaction": {"sender": "Charlie", "recipient": "David", "amount": "$50"}}]}}
```

**Note**: The first "transaction" uses sender/recipient/amount, so all "transaction" instances must use these same fields.

### Relations with Additional Fields

```jsonl
{"input": "John Smith is the CEO of TechCorp which is headquartered in Silicon Valley.", "output": {"relations": [{"employment": {"head": "John Smith", "tail": "TechCorp", "role": "CEO"}}, {"located_in": {"head": "TechCorp", "tail": "Silicon Valley"}}]}}
```

### Relations Combined with Entities

```jsonl
{"input": "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne.", "output": {"entities": {"person": ["Elon Musk"], "organization": ["SpaceX"], "location": ["Hawthorne"], "date": ["2002"]}, "relations": [{"founded": {"head": "Elon Musk", "tail": "SpaceX"}}, {"located_in": {"head": "SpaceX", "tail": "Hawthorne"}}]}}
```

### Empty Relations (Negative Example)

```jsonl
{"input": "The weather is nice today.", "output": {"relations": []}}
```

### Bidirectional Relations

```jsonl
{"input": "Alice and Bob are colleagues.", "output": {"relations": [{"colleague_of": {"head": "Alice", "tail": "Bob"}}, {"colleague_of": {"head": "Bob", "tail": "Alice"}}]}}
```

### Field Consistency: Relations vs JSON Structures

**Key Difference**:

- **Relations**: First occurrence defines the field structure for ALL instances of that relation type
  - All "works_for" relations must have the same fields
  - Consistency is enforced per relation type

- **JSON Structures**: Fields can vary between instances of the same parent type
  - Uses the union of all fields across instances
  - More flexible - instances can have different subsets of fields

**Example - Relations (Strict Consistency)**:
```jsonl
{"input": "Alice works for Google. Bob works for Microsoft.", "output": {"relations": [{"works_for": {"head": "Alice", "tail": "Google"}}, {"works_for": {"head": "Bob", "tail": "Microsoft"}}]}}
```
✓ Valid: Both "works_for" have the same fields (head, tail)

**Example - JSON Structures (Flexible Fields)**:
```jsonl
{"input": "Product A costs $10. Product B costs $20 and weighs 5kg.", "output": {"json_structures": [{"product": {"name": "A", "price": "$10"}}, {"product": {"name": "B", "price": "$20", "weight": "5kg"}}]}}
```
✓ Valid: The second instance has an extra "weight" field - this is allowed for json_structures
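The union-of-fields behavior described above for `json_structures` can be sketched as follows (an illustrative helper, not the library's merging code):

```python
def field_union(json_structures):
    """Collect the union of field names seen for each parent type,
    mirroring how flexible json_structures instances are combined."""
    fields = {}
    for structure in json_structures:
        for parent, values in structure.items():
            fields.setdefault(parent, set()).update(values)
    return fields

structures = [{"product": {"name": "A", "price": "$10"}},
              {"product": {"name": "B", "price": "$20", "weight": "5kg"}}]
print(field_union(structures))  # {'product': {'name', 'price', 'weight'}}
```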

---

## 5. Combined Multi-Task Examples

### Entities + Classifications

```jsonl
{"input": "Apple Inc. announced record profits. This is great news for investors.", "output": {"entities": {"organization": ["Apple Inc."]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}}
```

### Entities + JSON Structures

```jsonl
{"input": "Contact John Doe at john@example.com. He works at TechCorp.", "output": {"entities": {"person": ["John Doe"], "organization": ["TechCorp"]}, "json_structures": [{"contact": {"name": "John Doe", "email": "john@example.com", "company": "TechCorp"}}]}}
```

### Entities + Relations

```jsonl
{"input": "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne.", "output": {"entities": {"person": ["Elon Musk"], "organization": ["SpaceX"], "location": ["Hawthorne"], "date": ["2002"]}, "relations": [{"founded": {"head": "Elon Musk", "tail": "SpaceX", "year": "2002"}}, {"located_in": {"head": "SpaceX", "tail": "Hawthorne"}}]}}
```

### Classifications + JSON Structures

```jsonl
{"input": "Premium subscription for $99/month includes unlimited access. Great value!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}], "json_structures": [{"subscription": {"tier": "Premium", "price": "$99/month", "features": ["unlimited access"]}}]}}
```

### Entities + Classifications + JSON Structures

```jsonl
{"input": "Apple CEO Tim Cook unveiled iPhone 15 for $999. Analysts are optimistic.", "output": {"entities": {"person": ["Tim Cook"], "organization": ["Apple"], "product": ["iPhone 15"]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}], "json_structures": [{"product_announcement": {"company": "Apple", "product": "iPhone 15", "price": "$999", "presenter": "Tim Cook"}}]}}
```

### Entities + Relations + Classifications

```jsonl
{"input": "Sarah founded TechStart in 2020. The company is doing exceptionally well.", "output": {"entities": {"person": ["Sarah"], "organization": ["TechStart"], "date": ["2020"]}, "relations": [{"founded": {"head": "Sarah", "tail": "TechStart", "year": "2020"}}], "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}}
```

### All Four Tasks Combined

```jsonl
{"input": "Breaking: Apple announces new iPhone 15 with improved camera. Analysts are optimistic about sales projections.", "output": {"entities": {"company": ["Apple"], "product": ["iPhone 15"]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}, {"task": "category", "labels": ["technology", "business", "sports", "entertainment"], "true_label": ["technology"]}], "json_structures": [{"news_article": {"company": "Apple", "product": "iPhone 15", "feature": "improved camera", "analyst_view": "optimistic"}}], "relations": [{"product_of": {"head": "iPhone 15", "tail": "Apple"}}]}}
```

### Multi-Task with Descriptions

```jsonl
{"input": "Dr. Johnson prescribed medication X for condition Y. Patient shows improvement.", "output": {"entities": {"person": ["Dr. Johnson"], "medication": ["medication X"], "condition": ["condition Y"]}, "entity_descriptions": {"person": "Healthcare provider names", "medication": "Prescribed drugs", "condition": "Medical conditions"}, "classifications": [{"task": "patient_status", "labels": ["improving", "stable", "declining"], "true_label": ["improving"], "label_descriptions": {"improving": "Patient condition getting better", "stable": "No change in condition", "declining": "Patient condition worsening"}}], "json_structures": [{"prescription": {"doctor": "Dr. Johnson", "medication": "medication X", "condition": "condition Y"}}], "json_descriptions": {"prescription": {"doctor": "Prescribing physician", "medication": "Prescribed drug name", "condition": "Diagnosed condition"}}}}
```

### Partial Multi-Task (Some Tasks Empty)

**Note**: While you can include empty dictionaries/lists for some tasks, at least one task must have content.

```jsonl
{"input": "The weather forecast predicts rain tomorrow.", "output": {"entities": {}, "classifications": [{"task": "weather", "labels": ["sunny", "rainy", "cloudy", "snowy"], "true_label": ["rainy"]}], "json_structures": []}}
```

This is valid because it has a classification task. However, if all tasks were empty, it would fail validation.

---

## 6. Format Edge Cases

### Completely Empty Output

**⚠️ Note**: Examples must have at least one task (entities, classifications, structures, or relations). Completely empty outputs are not valid training examples.

```jsonl
{"input": "Random text with no specific information.", "output": {"entities": {}, "classifications": [], "json_structures": [], "relations": []}}
```

This format will fail validation. Each example must contain at least one annotation.
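A minimal sketch of the at-least-one-task rule (a shallow check over the four task keys; nested-empty values such as `{"person": []}` would need a deeper test):

```python
TASK_KEYS = ("entities", "classifications", "json_structures", "relations")

def has_annotations(output):
    """True if at least one task key in the output carries any content.
    Note: this is shallow — an entities dict with only empty mention
    lists would still count as content here."""
    return any(output.get(key) for key in TASK_KEYS)

empty = {"entities": {}, "classifications": [], "json_structures": [], "relations": []}
print(has_annotations(empty))                               # False
print(has_annotations({"entities": {"person": ["Alice"]}})) # True
```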
### Empty Entities Dictionary

**⚠️ Note**: While an empty entities dictionary is syntactically valid, examples must have at least one task. If you only have empty entities, add at least one other task (classification, structure, or relation).

```jsonl
{"input": "The weather is nice today.", "output": {"entities": {}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative"], "true_label": ["positive"]}]}}
```

### Empty Classifications List

**⚠️ Note**: While an empty classifications list is syntactically valid, examples must have at least one task. If you only have empty classifications, add at least one other task.

```jsonl
{"input": "Some generic text.", "output": {"classifications": [], "entities": {"location": ["text"]}}}
```

### Very Long Label Lists

```jsonl
{"input": "Sample text for many labels.", "output": {"classifications": [{"task": "topic", "labels": ["label1", "label2", "label3", "label4", "label5", "label6", "label7", "label8", "label9", "label10", "label11", "label12", "label13", "label14", "label15", "label16", "label17", "label18", "label19", "label20"], "true_label": ["label5"]}]}}
```

### Very Short Text

```jsonl
{"input": "Yes.", "output": {"classifications": [{"task": "response", "labels": ["yes", "no", "maybe"], "true_label": ["yes"]}]}}
{"input": "OK", "output": {"entities": {}}}
```

### Special Characters in Labels

```jsonl
{"input": "The C++ programming language.", "output": {"entities": {"programming_language": ["C++"]}}}
{"input": "Use the @ symbol for mentions.", "output": {"entities": {"symbol": ["@"]}}}
```

### Special Characters in Values

```jsonl
{"input": "Price is $1,299.99 (including tax).", "output": {"json_structures": [{"pricing": {"amount": "$1,299.99", "note": "(including tax)"}}]}}
```

### Unicode and Non-ASCII Characters

```jsonl
{"input": "Café Münchën serves crème brûlée.", "output": {"entities": {"location": ["Café Münchën"], "food": ["crème brûlée"]}}}
{"input": "东京 Tokyo is the capital.", "output": {"entities": {"location": ["东京", "Tokyo"]}}}
```

### Quotes and Escaping

```jsonl
{"input": "He said \"hello\" to me.", "output": {"entities": {"quote": ["\"hello\""]}}}
```
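
When generating JSONL programmatically, it is safer to build plain dicts and let `json.dumps` handle the escaping than to assemble strings by hand. A minimal sketch (the record shown is illustrative):

```python
import json

# Build one training record as a plain dict; json.dumps escapes the
# embedded quotes and the newline, producing a single valid JSONL line.
record = {
    "input": 'He said "hello" to me.\nSecond line.',
    "output": {"entities": {"quote": ['"hello"']}},
}

# ensure_ascii=False keeps non-ASCII text (e.g. "crème brûlée") readable;
# the file must then be written with UTF-8 encoding.
line = json.dumps(record, ensure_ascii=False)

# Round-trip check: parsing the line recovers the original record,
# and the raw newline was escaped so the record stays on one line.
assert json.loads(line) == record
assert "\n" not in line
```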

### Newlines in Text

```jsonl
{"input": "First line.\nSecond line.", "output": {"entities": {"text": ["First line", "Second line"]}}}
```

### Numbers as Strings vs Entity Names

```jsonl
{"input": "Room 123 on floor 4.", "output": {"json_structures": [{"location": {"room": "123", "floor": "4"}}]}}
```

### Boolean-like Values

```jsonl
{"input": "Status is active, notifications enabled.", "output": {"json_structures": [{"settings": {"status": "active", "notifications": "enabled"}}]}}
```

### Empty String Values

```jsonl
{"input": "Name: John, Age: unknown", "output": {"json_structures": [{"person": {"name": "John", "age": ""}}]}}
```

### Multiple Empty Lines in JSONL

```jsonl
{"input": "First example.", "output": {"entities": {"type": ["example"]}}}


{"input": "Second example.", "output": {"entities": {"type": ["example"]}}}
```
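
The section title suggests that blank lines between records are tolerated; a reader that skips them could look like the following sketch (not the library's own loader):

```python
import io
import json

def read_jsonl(fp):
    """Parse a JSONL stream, skipping empty or whitespace-only lines."""
    records = []
    for line in fp:
        if not line.strip():
            continue  # blank separator lines carry no record
        records.append(json.loads(line))
    return records

raw = (
    '{"input": "First example.", "output": {"entities": {"type": ["example"]}}}\n'
    "\n"
    "\n"
    '{"input": "Second example.", "output": {"entities": {"type": ["example"]}}}\n'
)
records = read_jsonl(io.StringIO(raw))
assert len(records) == 2
```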

---

## Schema Component Reference

### entities

- **Type**: `dict[str, list[str]]`
- **Format**: `{"entity_type": ["mention1", "mention2", ...]}`
- **Example**: `{"person": ["John", "Alice"], "location": ["NYC"]}`

### entity_descriptions

- **Type**: `dict[str, str]`
- **Format**: `{"entity_type": "description text"}`
- **Example**: `{"person": "Names of people", "location": "Geographic places"}`

### classifications

- **Type**: `list[dict]`
- **Required fields**: `task`, `labels`, `true_label`
- **Optional fields**: `multi_label`, `prompt`, `examples`, `label_descriptions`
- **Example**: `[{"task": "sentiment", "labels": ["pos", "neg"], "true_label": ["pos"]}]`

### json_structures

- **Type**: `list[dict]`
- **Single instance**: `[{"parent_name": {"field1": "value1", "field2": ["list", "values"]}}]`
- **Multiple instances (same parent)**: `[{"parent": {...}}, {"parent": {...}}]` - same parent key, separate dicts
- **Multiple types**: `[{"parent1": {...}}, {"parent2": {...}}]` - different parent keys
- **Choice format**: `{"field": {"value": "selected", "choices": ["opt1", "opt2"]}}`
- **Example**: `[{"product": {"name": "Item", "price": "$10"}}, {"product": {"name": "Item2", "price": "$20"}}]`

### json_descriptions

- **Type**: `dict[str, dict[str, str]]`
- **Format**: `{"parent": {"field": "description"}}`
- **Example**: `{"product": {"name": "Product name", "price": "Cost in USD"}}`

### relations

- **Type**: `list[dict]`
- **Standard format**: `[{"relation_name": {"head": "entity1", "tail": "entity2"}}]`
- **With custom fields**: `[{"relation_name": {"sender": "A", "recipient": "B", "amount": "$100"}}]`
- **Example**: `[{"works_for": {"head": "John", "tail": "Company"}}, {"founded": {"head": "Alice", "tail": "StartupX"}}]`
- **⚠️ Field constraint**: First occurrence of each relation type defines the field structure for ALL instances of that type
- **Note**: While "head" and "tail" are common, you can use ANY field names - just keep them consistent per relation type
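
The first-occurrence rule above can be checked with a short standalone helper (a sketch of the idea behind `validate_relation_consistency()`, not the library function itself):

```python
def check_relation_consistency(examples):
    """Flag relation instances whose field names differ from the set
    fixed by the first occurrence of that relation type."""
    expected = {}  # relation name -> frozenset of field names
    errors = []
    for i, example in enumerate(examples):
        for relation in example.get("output", {}).get("relations", []):
            for name, fields in relation.items():
                keys = frozenset(fields)
                if name not in expected:
                    expected[name] = keys  # first occurrence defines the schema
                elif keys != expected[name]:
                    errors.append(
                        f"example {i}: '{name}' uses fields {sorted(keys)}, "
                        f"expected {sorted(expected[name])}"
                    )
    return errors

good = [{"output": {"relations": [
    {"works_for": {"head": "John", "tail": "Company"}},
    {"works_for": {"head": "Alice", "tail": "StartupX"}},
]}}]
bad = good + [{"output": {"relations": [
    {"works_for": {"sender": "A", "recipient": "B"}},
]}}]
assert check_relation_consistency(good) == []
assert check_relation_consistency(bad) != []
```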

---

## Tips for Dataset Creation

1. **Use diverse examples** to improve model generalization
2. **Include edge cases** - but remember each example must have at least one task
3. **Provide descriptions** when possible to improve accuracy
4. **Balance your classes** in classification tasks
5. **Use realistic text** that matches your target domain
6. **Include multiple instances** for JSON structures when applicable
7. **For negative examples**, include at least one task (e.g., empty entities but a classification, or empty classifications but entities)
8. **Mix task types** to train multi-task capabilities
9. **Use consistent formatting** for similar examples
10. **Include special characters** to ensure robust handling
11. **Validate your dataset** using `TrainingDataset.validate(strict=True)` to catch annotation errors early
12. **Check relation consistency** using `validate_relation_consistency()` to ensure all relation types have consistent field structures

## Validation Checklist

Make sure your JSONL file is valid by checking:

- [ ] Each line is valid JSON
- [ ] Required fields (`input`/`output` or `text`/`schema`) are present
- [ ] **At least one task is present** (entities, classifications, structures, or relations)
- [ ] Schema structure matches the expected format
- [ ] Entity mentions exist in the input text (checked in strict validation mode)
- [ ] Classification labels are from the defined label set
- [ ] `true_label` is a list or string (a bare string is converted to a list internally)
- [ ] For multi-label classification, `multi_label` is set to `true` when multiple labels are provided
- [ ] JSON structure fields match between instances of the same parent (flexible - the union of fields is used)
- [ ] **Relation field consistency**: all instances of the same relation type use the same field names (determined by the first occurrence)
- [ ] No trailing commas in JSON objects
- [ ] Special characters are properly escaped
- [ ] File encoding is UTF-8
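
Several of these checks can be automated. The following standalone sketch covers the JSON, required-field, at-least-one-task, and label checks; it is independent of the library's `TrainingDataset.validate` and handles only the `input`/`output` form:

```python
import json

TASK_KEYS = ("entities", "classifications", "json_structures", "relations")

def validate_jsonl_line(line):
    """Return a list of problems for one JSONL record line (empty = OK)."""
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc}"]
    if "input" not in record or "output" not in record:
        return ["missing required 'input'/'output' fields"]
    problems = []
    output = record["output"]
    # At least one non-empty task must be present.
    if not any(output.get(key) for key in TASK_KEYS):
        problems.append("no task present")
    # Classification checks: labels drawn from the label set,
    # multi_label flag set when several true labels are given.
    for clf in output.get("classifications", []):
        true = clf.get("true_label", [])
        if isinstance(true, str):
            true = [true]  # the bare-string form is normalized to a list
        for label in true:
            if label not in clf.get("labels", []):
                problems.append(f"true_label {label!r} not in labels")
        if len(true) > 1 and not clf.get("multi_label"):
            problems.append("multiple true labels but multi_label is not true")
    return problems

ok = ('{"input": "Great!", "output": {"classifications": '
     '[{"task": "sentiment", "labels": ["pos", "neg"], "true_label": ["pos"]}]}}')
bad = '{"input": "Hm.", "output": {"classifications": []}}'
assert validate_jsonl_line(ok) == []
assert validate_jsonl_line(bad) == ["no task present"]
```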

### Validation Modes

The implementation supports two validation modes:

- **Standard validation**: Checks format correctness, required fields, label consistency
- **Strict validation**: Additionally checks that entity mentions and relation values exist in the input text (case-insensitive substring matching)

Use strict validation during dataset creation to catch annotation errors early.
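
The strict-mode span check amounts to a case-insensitive substring test, roughly:

```python
def mentions_in_text(text, entities):
    """Strict-style check: every entity mention must occur in the input
    text as a case-insensitive substring. Returns the missing mentions."""
    haystack = text.lower()
    missing = []
    for entity_type, mentions in entities.items():
        for mention in mentions:
            if mention.lower() not in haystack:
                missing.append((entity_type, mention))
    return missing

assert mentions_in_text("John visited NYC.", {"person": ["john"], "location": ["NYC"]}) == []
assert mentions_in_text("John visited NYC.", {"person": ["Alice"]}) == [("person", "Alice")]
```

Because the match is case-insensitive, an annotation like `"john"` still passes against the text "John visited NYC.", while a mention that never appears is reported.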