From 9b32c8dd296920c76476e65378b16a2a854ff011 Mon Sep 17 00:00:00 2001 From: Babayaga Date: Fri, 6 Mar 2026 12:59:32 +0100 Subject: [PATCH] gliner2 --- packages/GLiNER2/.gitignore | 97 ++ packages/GLiNER2/LICENSE | 201 +++ packages/GLiNER2/README.md | 1031 ++++++++++++ packages/GLiNER2/RELEASE.md | 79 + packages/GLiNER2/benchmark_statistical.py | 401 +++++ packages/GLiNER2/gliner2/__init__.py | 23 + packages/GLiNER2/gliner2/api_client.py | 989 +++++++++++ .../GLiNER2/gliner2/inference/__init__.py | 1 + packages/GLiNER2/gliner2/inference/engine.py | 1458 +++++++++++++++++ .../GLiNER2/gliner2/inference/schema_model.py | 191 +++ packages/GLiNER2/gliner2/layers.py | 249 +++ packages/GLiNER2/gliner2/model.py | 692 ++++++++ packages/GLiNER2/gliner2/old_trainer.py | 322 ++++ packages/GLiNER2/gliner2/processor.py | 1072 ++++++++++++ packages/GLiNER2/gliner2/training/__init__.py | 0 packages/GLiNER2/gliner2/training/data.py | 1277 +++++++++++++++ packages/GLiNER2/gliner2/training/lora.py | 836 ++++++++++ packages/GLiNER2/gliner2/training/trainer.py | 1409 ++++++++++++++++ packages/GLiNER2/pyproject.toml | 25 + packages/GLiNER2/tutorial/1-classification.md | 663 ++++++++ packages/GLiNER2/tutorial/10-lora_adapters.md | 973 +++++++++++ .../GLiNER2/tutorial/11-adapter_switching.md | 201 +++ packages/GLiNER2/tutorial/2-ner.md | 372 +++++ .../GLiNER2/tutorial/3-json_extraction.md | 504 ++++++ packages/GLiNER2/tutorial/4-combined.md | 357 ++++ packages/GLiNER2/tutorial/5-validator.md | 112 ++ .../GLiNER2/tutorial/6-relation_extraction.md | 643 ++++++++ packages/GLiNER2/tutorial/7-api.md | 514 ++++++ packages/GLiNER2/tutorial/8-train_data.md | 630 +++++++ packages/GLiNER2/tutorial/9-training.md | 1296 +++++++++++++++ 30 files changed, 16618 insertions(+) create mode 100644 packages/GLiNER2/.gitignore create mode 100644 packages/GLiNER2/LICENSE create mode 100644 packages/GLiNER2/README.md create mode 100644 packages/GLiNER2/RELEASE.md create mode 100644 
packages/GLiNER2/benchmark_statistical.py create mode 100644 packages/GLiNER2/gliner2/__init__.py create mode 100644 packages/GLiNER2/gliner2/api_client.py create mode 100644 packages/GLiNER2/gliner2/inference/__init__.py create mode 100644 packages/GLiNER2/gliner2/inference/engine.py create mode 100644 packages/GLiNER2/gliner2/inference/schema_model.py create mode 100644 packages/GLiNER2/gliner2/layers.py create mode 100644 packages/GLiNER2/gliner2/model.py create mode 100644 packages/GLiNER2/gliner2/old_trainer.py create mode 100644 packages/GLiNER2/gliner2/processor.py create mode 100644 packages/GLiNER2/gliner2/training/__init__.py create mode 100644 packages/GLiNER2/gliner2/training/data.py create mode 100644 packages/GLiNER2/gliner2/training/lora.py create mode 100644 packages/GLiNER2/gliner2/training/trainer.py create mode 100644 packages/GLiNER2/pyproject.toml create mode 100644 packages/GLiNER2/tutorial/1-classification.md create mode 100644 packages/GLiNER2/tutorial/10-lora_adapters.md create mode 100644 packages/GLiNER2/tutorial/11-adapter_switching.md create mode 100644 packages/GLiNER2/tutorial/2-ner.md create mode 100644 packages/GLiNER2/tutorial/3-json_extraction.md create mode 100644 packages/GLiNER2/tutorial/4-combined.md create mode 100644 packages/GLiNER2/tutorial/5-validator.md create mode 100644 packages/GLiNER2/tutorial/6-relation_extraction.md create mode 100644 packages/GLiNER2/tutorial/7-api.md create mode 100644 packages/GLiNER2/tutorial/8-train_data.md create mode 100644 packages/GLiNER2/tutorial/9-training.md diff --git a/packages/GLiNER2/.gitignore b/packages/GLiNER2/.gitignore new file mode 100644 index 0000000..1cb19c1 --- /dev/null +++ b/packages/GLiNER2/.gitignore @@ -0,0 +1,97 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ 
+.installed.cfg +*.egg + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Ruff +.ruff_cache/ + +# IDEs +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db + +# Model files (typically large) +*.pt +*.pth +*.bin +*.onnx +*.safetensors + +# Logs +*.log + +test_api_client.py \ No newline at end of file diff --git a/packages/GLiNER2/LICENSE b/packages/GLiNER2/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/packages/GLiNER2/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/packages/GLiNER2/README.md b/packages/GLiNER2/README.md new file mode 100644 index 0000000..5798893 --- /dev/null +++ b/packages/GLiNER2/README.md @@ -0,0 +1,1031 @@ +# GLiNER2: Unified Schema-Based Information Extraction + +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/) +[![PyPI version](https://badge.fury.io/py/gliner2.svg)](https://badge.fury.io/py/gliner2) +[![Downloads](https://pepy.tech/badge/gliner2)](https://pepy.tech/project/gliner2) + +> *Extract entities, classify text, parse structured data, and extract relationsβ€”all in one efficient model.* + +GLiNER2 unifies **Named Entity Recognition**, **Text Classification**, **Structured Data Extraction**, and **Relation Extraction** into a single 205M parameter model. It provides efficient CPU-based inference without requiring complex pipelines or external API dependencies. + +## ✨ Why GLiNER2? 
+ +- **🎯 One Model, Four Tasks**: Entities, classification, structured data, and relations in a single forward pass +- **πŸ’» CPU First**: Lightning-fast inference on standard hardwareβ€”no GPU required +- **πŸ›‘οΈ Privacy**: 100% local processing, zero external dependencies + +## πŸš€ Installation & Quick Start + +```bash +pip install gliner2 +``` + +```python +from gliner2 import GLiNER2 + +# Load model once, use everywhere +extractor = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + +# Extract entities in one line +text = "Apple CEO Tim Cook announced iPhone 15 in Cupertino yesterday." +result = extractor.extract_entities(text, ["company", "person", "product", "location"]) + +print(result) +# {'entities': {'company': ['Apple'], 'person': ['Tim Cook'], 'product': ['iPhone 15'], 'location': ['Cupertino']}} +``` + +### 🌐 API Access: GLiNER XL 1B + +Our biggest and most powerful modelβ€”**GLiNER XL 1B**β€”is available exclusively via API. No GPU required, no model downloads, just instant access to state-of-the-art extraction. Get your API key at [gliner.pioneer.ai](https://gliner.pioneer.ai). + +```python +from gliner2 import GLiNER2 + +# Access GLiNER XL 1B via API +extractor = GLiNER2.from_api() # Uses PIONEER_API_KEY env variable + +result = extractor.extract_entities( + "OpenAI CEO Sam Altman announced GPT-5 at their San Francisco headquarters.", + ["company", "person", "product", "location"] +) +# {'entities': {'company': ['OpenAI'], 'person': ['Sam Altman'], 'product': ['GPT-5'], 'location': ['San Francisco']}} +``` + +## πŸ“¦ Available Models + +| Model | Parameters | Description | Use Case | +|-------|------------|-------------|--------------------------------------------------| +| `fastino/gliner2-base-v1` | 205M | base size | Extraction / classification | +| `fastino/gliner2-large-v1` | 340M | large size | Extraction / classification | + +The models are available on [Hugging Face](https://huggingface.co/collections/fastino/gliner2-family). 
+ +## πŸ“š Documentation & Tutorials + +Comprehensive guides for all GLiNER2 features: + +### Core Features +- **[Text Classification](tutorial/1-classification.md)** - Single and multi-label classification with confidence scores +- **[Entity Extraction](tutorial/2-ner.md)** - Named entity recognition with descriptions and spans +- **[Structured Data Extraction](tutorial/3-json_extraction.md)** - Parse complex JSON structures from text +- **[Combined Schemas](tutorial/4-combined.md)** - Multi-task extraction in a single pass +- **[Regex Validators](tutorial/5-validator.md)** - Filter and validate extracted spans +- **[Relation Extraction](tutorial/6-relation_extraction.md)** - Extract relationships between entities +- **[API Access](tutorial/7-api.md)** - Use GLiNER2 via cloud API + +### Training & Customization +- **[Training Data Format](tutorial/8-train_data.md)** - Complete guide to preparing training data +- **[Model Training](tutorial/9-training.md)** - Train custom models for your domain +- **[LoRA Adapters](tutorial/10-lora_adapters.md)** - Parameter-efficient fine-tuning +- **[Adapter Switching](tutorial/11-adapter_switching.md)** - Switch between domain adapters + +## 🎯 Core Capabilities + +### 1. 
Entity Extraction +Extract named entities with optional descriptions for precision: + +```python +# Basic entity extraction +entities = extractor.extract_entities( + "Patient received 400mg ibuprofen for severe headache at 2 PM.", + ["medication", "dosage", "symptom", "time"] +) +# Output: {'entities': {'medication': ['ibuprofen'], 'dosage': ['400mg'], 'symptom': ['severe headache'], 'time': ['2 PM']}} + +# Enhanced with descriptions for medical accuracy +entities = extractor.extract_entities( + "Patient received 400mg ibuprofen for severe headache at 2 PM.", + { + "medication": "Names of drugs, medications, or pharmaceutical substances", + "dosage": "Specific amounts like '400mg', '2 tablets', or '5ml'", + "symptom": "Medical symptoms, conditions, or patient complaints", + "time": "Time references like '2 PM', 'morning', or 'after lunch'" + } +) +# Same output but with higher accuracy due to context descriptions + +# With confidence scores +entities = extractor.extract_entities( + "Apple Inc. CEO Tim Cook announced iPhone 15 in Cupertino.", + ["company", "person", "product", "location"], + include_confidence=True +) +# Output: { +# 'entities': { +# 'company': [{'text': 'Apple Inc.', 'confidence': 0.95}], +# 'person': [{'text': 'Tim Cook', 'confidence': 0.92}], +# 'product': [{'text': 'iPhone 15', 'confidence': 0.88}], +# 'location': [{'text': 'Cupertino', 'confidence': 0.90}] +# } +# } + +# With character positions (spans) +entities = extractor.extract_entities( + "Apple Inc. CEO Tim Cook announced iPhone 15 in Cupertino.", + ["company", "person", "product"], + include_spans=True +) +# Output: { +# 'entities': { +# 'company': [{'text': 'Apple Inc.', 'start': 0, 'end': 9}], +# 'person': [{'text': 'Tim Cook', 'start': 15, 'end': 23}], +# 'product': [{'text': 'iPhone 15', 'start': 35, 'end': 44}] +# } +# } + +# With both confidence and spans +entities = extractor.extract_entities( + "Apple Inc. 
CEO Tim Cook announced iPhone 15 in Cupertino.", + ["company", "person", "product"], + include_confidence=True, + include_spans=True +) +# Output: { +# 'entities': { +# 'company': [{'text': 'Apple Inc.', 'confidence': 0.95, 'start': 0, 'end': 9}], +# 'person': [{'text': 'Tim Cook', 'confidence': 0.92, 'start': 15, 'end': 23}], +# 'product': [{'text': 'iPhone 15', 'confidence': 0.88, 'start': 35, 'end': 44}] +# } +# } +``` + +### 2. Text Classification +Single or multi-label classification with configurable confidence: + +```python +# Sentiment analysis +result = extractor.classify_text( + "This laptop has amazing performance but terrible battery life!", + {"sentiment": ["positive", "negative", "neutral"]} +) +# Output: {'sentiment': 'negative'} + +# Multi-aspect classification +result = extractor.classify_text( + "Great camera quality, decent performance, but poor battery life.", + { + "aspects": { + "labels": ["camera", "performance", "battery", "display", "price"], + "multi_label": True, + "cls_threshold": 0.4 + } + } +) +# Output: {'aspects': ['camera', 'performance', 'battery']} + +# With confidence scores +result = extractor.classify_text( + "This laptop has amazing performance but terrible battery life!", + {"sentiment": ["positive", "negative", "neutral"]}, + include_confidence=True +) +# Output: {'sentiment': {'label': 'negative', 'confidence': 0.82}} + +# Multi-label with confidence +schema = extractor.create_schema().classification( + "topics", + ["technology", "business", "health", "politics", "sports"], + multi_label=True, + cls_threshold=0.3 +) +text = "Apple announced new health monitoring features in their latest smartwatch, boosting their stock price." +results = extractor.extract(text, schema, include_confidence=True) +# Output: { +# 'topics': [ +# {'label': 'technology', 'confidence': 0.92}, +# {'label': 'business', 'confidence': 0.78}, +# {'label': 'health', 'confidence': 0.65} +# ] +# } +``` + +### 3. 
Structured Data Extraction +Parse complex structured information with field-level control: + +```python +# Product information extraction +text = "iPhone 15 Pro Max with 256GB storage, A17 Pro chip, priced at $1199. Available in titanium and black colors." + +result = extractor.extract_json( + text, + { + "product": [ + "name::str::Full product name and model", + "storage::str::Storage capacity like 256GB or 1TB", + "processor::str::Chip or processor information", + "price::str::Product price with currency", + "colors::list::Available color options" + ] + } +) +# Output: { +# 'product': [{ +# 'name': 'iPhone 15 Pro Max', +# 'storage': '256GB', +# 'processor': 'A17 Pro chip', +# 'price': '$1199', +# 'colors': ['titanium', 'black'] +# }] +# } + +# Multiple structured entities +text = "Apple Inc. headquarters in Cupertino launched iPhone 15 for $999 and MacBook Air for $1299." + +result = extractor.extract_json( + text, + { + "company": [ + "name::str::Company name", + "location::str::Company headquarters or office location" + ], + "products": [ + "name::str::Product name and model", + "price::str::Product retail price" + ] + } +) +# Output: { +# 'company': [{'name': 'Apple Inc.', 'location': 'Cupertino'}], +# 'products': [ +# {'name': 'iPhone 15', 'price': '$999'}, +# {'name': 'MacBook Air', 'price': '$1299'} +# ] +# } + +# With confidence scores +result = extractor.extract_json( + "The MacBook Pro costs $1999 and features M3 chip, 16GB RAM, and 512GB storage.", + { + "product": [ + "name::str", + "price", + "features" + ] + }, + include_confidence=True +) +# Output: { +# 'product': [{ +# 'name': {'text': 'MacBook Pro', 'confidence': 0.95}, +# 'price': [{'text': '$1999', 'confidence': 0.92}], +# 'features': [ +# {'text': 'M3 chip', 'confidence': 0.88}, +# {'text': '16GB RAM', 'confidence': 0.90}, +# {'text': '512GB storage', 'confidence': 0.87} +# ] +# }] +# } + +# With character positions (spans) +result = extractor.extract_json( + "The MacBook Pro costs $1999 and 
features M3 chip.", + { + "product": [ + "name::str", + "price" + ] + }, + include_spans=True +) +# Output: { +# 'product': [{ +# 'name': {'text': 'MacBook Pro', 'start': 4, 'end': 15}, +# 'price': [{'text': '$1999', 'start': 22, 'end': 27}] +# }] +# } + +# With both confidence and spans +result = extractor.extract_json( + "The MacBook Pro costs $1999 and features M3 chip, 16GB RAM, and 512GB storage.", + { + "product": [ + "name::str", + "price", + "features" + ] + }, + include_confidence=True, + include_spans=True +) +# Output: { +# 'product': [{ +# 'name': {'text': 'MacBook Pro', 'confidence': 0.95, 'start': 4, 'end': 15}, +# 'price': [{'text': '$1999', 'confidence': 0.92, 'start': 22, 'end': 27}], +# 'features': [ +# {'text': 'M3 chip', 'confidence': 0.88, 'start': 32, 'end': 39}, +# {'text': '16GB RAM', 'confidence': 0.90, 'start': 41, 'end': 49}, +# {'text': '512GB storage', 'confidence': 0.87, 'start': 55, 'end': 68} +# ] +# }] +# } +``` + +### 4. Relation Extraction +Extract relationships between entities as directional tuples: + +```python +# Basic relation extraction +text = "John works for Apple Inc. and lives in San Francisco. Apple Inc. is located in Cupertino." + +result = extractor.extract_relations( + text, + ["works_for", "lives_in", "located_in"] +) +# Output: { +# 'relation_extraction': { +# 'works_for': [('John', 'Apple Inc.')], +# 'lives_in': [('John', 'San Francisco')], +# 'located_in': [('Apple Inc.', 'Cupertino')] +# } +# } + +# With descriptions for better accuracy +schema = extractor.create_schema().relations({ + "works_for": "Employment relationship where person works at organization", + "founded": "Founding relationship where person created organization", + "acquired": "Acquisition relationship where company bought another company", + "located_in": "Geographic relationship where entity is in a location" +}) + +text = "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne, California." 
+results = extractor.extract(text, schema) +# Output: { +# 'relation_extraction': { +# 'founded': [('Elon Musk', 'SpaceX')], +# 'located_in': [('SpaceX', 'Hawthorne, California')] +# } +# } + +# With confidence scores +results = extractor.extract_relations( + "John works for Apple Inc. and lives in San Francisco.", + ["works_for", "lives_in"], + include_confidence=True +) +# Output: { +# 'relation_extraction': { +# 'works_for': [{ +# 'head': {'text': 'John', 'confidence': 0.95}, +# 'tail': {'text': 'Apple Inc.', 'confidence': 0.92} +# }], +# 'lives_in': [{ +# 'head': {'text': 'John', 'confidence': 0.94}, +# 'tail': {'text': 'San Francisco', 'confidence': 0.91} +# }] +# } +# } + +# With character positions (spans) +results = extractor.extract_relations( + "John works for Apple Inc. and lives in San Francisco.", + ["works_for", "lives_in"], + include_spans=True +) +# Output: { +# 'relation_extraction': { +# 'works_for': [{ +# 'head': {'text': 'John', 'start': 0, 'end': 4}, +# 'tail': {'text': 'Apple Inc.', 'start': 15, 'end': 25} +# }], +# 'lives_in': [{ +# 'head': {'text': 'John', 'start': 0, 'end': 4}, +# 'tail': {'text': 'San Francisco', 'start': 33, 'end': 46} +# }] +# } +# } + +# With both confidence and spans +results = extractor.extract_relations( + "John works for Apple Inc. and lives in San Francisco.", + ["works_for", "lives_in"], + include_confidence=True, + include_spans=True +) +# Output: { +# 'relation_extraction': { +# 'works_for': [{ +# 'head': {'text': 'John', 'confidence': 0.95, 'start': 0, 'end': 4}, +# 'tail': {'text': 'Apple Inc.', 'confidence': 0.92, 'start': 15, 'end': 25} +# }], +# 'lives_in': [{ +# 'head': {'text': 'John', 'confidence': 0.94, 'start': 0, 'end': 4}, +# 'tail': {'text': 'San Francisco', 'confidence': 0.91, 'start': 33, 'end': 46} +# }] +# } +# } +``` + +### 5. 
Multi-Task Schema Composition +Combine all extraction types when you need comprehensive analysis: + +```python +# Use create_schema() for multi-task scenarios +schema = (extractor.create_schema() + # Extract key entities + .entities({ + "person": "Names of people, executives, or individuals", + "company": "Organization, corporation, or business names", + "product": "Products, services, or offerings mentioned" + }) + + # Classify the content + .classification("sentiment", ["positive", "negative", "neutral"]) + .classification("category", ["technology", "business", "finance", "healthcare"]) + + # Extract relationships + .relations(["works_for", "founded", "located_in"]) + + # Extract structured product details + .structure("product_info") + .field("name", dtype="str") + .field("price", dtype="str") + .field("features", dtype="list") + .field("availability", dtype="str", choices=["in_stock", "pre_order", "sold_out"]) +) + +# Comprehensive extraction in one pass +text = "Apple CEO Tim Cook unveiled the revolutionary iPhone 15 Pro for $999. The device features an A17 Pro chip and titanium design. Tim Cook works for Apple, which is located in Cupertino." + +results = extractor.extract(text, schema) +# Output: { +# 'entities': { +# 'person': ['Tim Cook'], +# 'company': ['Apple'], +# 'product': ['iPhone 15 Pro'] +# }, +# 'sentiment': 'positive', +# 'category': 'technology', +# 'relation_extraction': { +# 'works_for': [('Tim Cook', 'Apple')], +# 'located_in': [('Apple', 'Cupertino')] +# }, +# 'product_info': [{ +# 'name': 'iPhone 15 Pro', +# 'price': '$999', +# 'features': ['A17 Pro chip', 'titanium design'], +# 'availability': 'in_stock' +# }] +# } +``` + +## 🏭 Example Usage Scenarios + +### Financial Document Processing + +```python +financial_text = """ +Transaction Report: Goldman Sachs processed a $2.5M equity trade for Tesla Inc. +on March 15, 2024. Commission: $1,250. Status: Completed. 
+""" + +# Extract structured financial data +result = extractor.extract_json( + financial_text, + { + "transaction": [ + "broker::str::Financial institution or brokerage firm", + "amount::str::Transaction amount with currency", + "security::str::Stock, bond, or financial instrument", + "date::str::Transaction date", + "commission::str::Fees or commission charged", + "status::str::Transaction status", + "type::[equity|bond|option|future|forex]::str::Type of financial instrument" + ] + } +) +# Output: { +# 'transaction': [{ +# 'broker': 'Goldman Sachs', +# 'amount': '$2.5M', +# 'security': 'Tesla Inc.', +# 'date': 'March 15, 2024', +# 'commission': '$1,250', +# 'status': 'Completed', +# 'type': 'equity' +# }] +# } +``` + +### Healthcare Information Extraction + +```python +medical_record = """ +Patient: Sarah Johnson, 34, presented with acute chest pain and shortness of breath. +Prescribed: Lisinopril 10mg daily, Metoprolol 25mg twice daily. +Follow-up scheduled for next Tuesday. +""" + +result = extractor.extract_json( + medical_record, + { + "patient_info": [ + "name::str::Patient full name", + "age::str::Patient age", + "symptoms::list::Reported symptoms or complaints" + ], + "prescriptions": [ + "medication::str::Drug or medication name", + "dosage::str::Dosage amount and frequency", + "frequency::str::How often to take the medication" + ] + } +) +# Output: { +# 'patient_info': [{ +# 'name': 'Sarah Johnson', +# 'age': '34', +# 'symptoms': ['acute chest pain', 'shortness of breath'] +# }], +# 'prescriptions': [ +# {'medication': 'Lisinopril', 'dosage': '10mg', 'frequency': 'daily'}, +# {'medication': 'Metoprolol', 'dosage': '25mg', 'frequency': 'twice daily'} +# ] +# } +``` + +### Legal Contract Analysis + +```python +contract_text = """ +Service Agreement between TechCorp LLC and DataSystems Inc., effective January 1, 2024. +Monthly fee: $15,000. Contract term: 24 months with automatic renewal. +Termination clause: 30-day written notice required. 
+""" + +# Multi-task extraction for comprehensive analysis +schema = (extractor.create_schema() + .entities(["company", "date", "duration", "fee"]) + .classification("contract_type", ["service", "employment", "nda", "partnership"]) + .relations(["signed_by", "involves", "dated"]) + .structure("contract_terms") + .field("parties", dtype="list") + .field("effective_date", dtype="str") + .field("monthly_fee", dtype="str") + .field("term_length", dtype="str") + .field("renewal", dtype="str", choices=["automatic", "manual", "none"]) + .field("termination_notice", dtype="str") +) + +results = extractor.extract(contract_text, schema) +# Output: { +# 'entities': { +# 'company': ['TechCorp LLC', 'DataSystems Inc.'], +# 'date': ['January 1, 2024'], +# 'duration': ['24 months'], +# 'fee': ['$15,000'] +# }, +# 'contract_type': 'service', +# 'relation_extraction': { +# 'involves': [('TechCorp LLC', 'DataSystems Inc.')], +# 'dated': [('Service Agreement', 'January 1, 2024')] +# }, +# 'contract_terms': [{ +# 'parties': ['TechCorp LLC', 'DataSystems Inc.'], +# 'effective_date': 'January 1, 2024', +# 'monthly_fee': '$15,000', +# 'term_length': '24 months', +# 'renewal': 'automatic', +# 'termination_notice': '30-day written notice' +# }] +# } +``` + +### Knowledge Graph Construction + +```python +# Extract entities and relations for knowledge graph building +text = """ +Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne, California. +SpaceX acquired Swarm Technologies in 2021. Many engineers work for SpaceX. 
+""" + +schema = (extractor.create_schema() + .entities(["person", "organization", "location", "date"]) + .relations({ + "founded": "Founding relationship where person created organization", + "acquired": "Acquisition relationship where company bought another company", + "located_in": "Geographic relationship where entity is in a location", + "works_for": "Employment relationship where person works at organization" + }) +) + +results = extractor.extract(text, schema) +# Output: { +# 'entities': { +# 'person': ['Elon Musk', 'engineers'], +# 'organization': ['SpaceX', 'Swarm Technologies'], +# 'location': ['Hawthorne, California'], +# 'date': ['2002', '2021'] +# }, +# 'relation_extraction': { +# 'founded': [('Elon Musk', 'SpaceX')], +# 'acquired': [('SpaceX', 'Swarm Technologies')], +# 'located_in': [('SpaceX', 'Hawthorne, California')], +# 'works_for': [('engineers', 'SpaceX')] +# } +# } +``` + +## βš™οΈ Advanced Configuration + +### Custom Confidence Thresholds + +```python +# High-precision extraction for critical fields +result = extractor.extract_json( + text, + { + "financial_data": [ + "account_number::str::Bank account number", # default threshold + "amount::str::Transaction amount", # default threshold + "routing_number::str::Bank routing number" # default threshold + ] + }, + threshold=0.9 # High confidence for all fields +) + +# Per-field thresholds using schema builder (for multi-task scenarios) +schema = (extractor.create_schema() + .structure("sensitive_data") + .field("ssn", dtype="str", threshold=0.95) # Highest precision + .field("email", dtype="str", threshold=0.8) # Medium precision + .field("phone", dtype="str", threshold=0.7) # Lower precision +) +``` + +### Field Types and Constraints + +```python +# Structured extraction with choices and types +result = extractor.extract_json( + "Premium subscription at $99/month with mobile and web access.", + { + "subscription": [ + "tier::[basic|premium|enterprise]::str::Subscription level", + 
"price::str::Monthly or annual cost", + "billing::[monthly|annual]::str::Billing frequency", + "features::[mobile|web|api|analytics]::list::Included features" + ] + } +) +# Output: { +# 'subscription': [{ +# 'tier': 'premium', +# 'price': '$99/month', +# 'billing': 'monthly', +# 'features': ['mobile', 'web'] +# }] +# } +``` + +## πŸ” Regex Validators + +Filter extracted spans to ensure they match expected patterns, improving extraction quality and reducing false positives. + +```python +from gliner2 import GLiNER2, RegexValidator + +extractor = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + +# Email validation +email_validator = RegexValidator(r"^[\w\.-]+@[\w\.-]+\.\w+$") +schema = (extractor.create_schema() + .structure("contact") + .field("email", dtype="str", validators=[email_validator]) +) + +text = "Contact: john@company.com, not-an-email, jane@domain.org" +results = extractor.extract(text, schema) +# Output: {'contact': [{'email': 'john@company.com'}]} # Only valid emails + +# Phone number validation (US format) +phone_validator = RegexValidator(r"\(\d{3}\)\s\d{3}-\d{4}", mode="partial") +schema = (extractor.create_schema() + .structure("contact") + .field("phone", dtype="str", validators=[phone_validator]) +) + +text = "Call (555) 123-4567 or 5551234567" +results = extractor.extract(text, schema) +# Output: {'contact': [{'phone': '(555) 123-4567'}]} # Second number filtered out + +# URL validation +url_validator = RegexValidator(r"^https?://", mode="partial") +schema = (extractor.create_schema() + .structure("links") + .field("url", dtype="list", validators=[url_validator]) +) + +text = "Visit https://example.com or www.site.com" +results = extractor.extract(text, schema) +# Output: {'links': [{'url': ['https://example.com']}]} # www.site.com filtered out + +# Exclude test data +import re +no_test_validator = RegexValidator(r"^(test|demo|sample)", exclude=True, flags=re.IGNORECASE) +schema = (extractor.create_schema() + .structure("products") + 
.field("name", dtype="list", validators=[no_test_validator]) +) + +text = "Products: iPhone, Test Phone, Samsung Galaxy" +results = extractor.extract(text, schema) +# Output: {'products': [{'name': ['iPhone', 'Samsung Galaxy']}]} # Test Phone excluded + +# Multiple validators (all must pass) +username_validators = [ + RegexValidator(r"^[a-zA-Z0-9_]+$"), # Alphanumeric + underscore + RegexValidator(r"^.{3,20}$"), # 3-20 characters + RegexValidator(r"^(?!admin)", exclude=True, flags=re.IGNORECASE) # No "admin" +] + +schema = (extractor.create_schema() + .structure("user") + .field("username", dtype="str", validators=username_validators) +) + +text = "Users: ab, john_doe, user@domain, admin, valid_user123" +results = extractor.extract(text, schema) +# Output: {'user': [{'username': 'john_doe'}]} # Only valid usernames +``` + +## πŸ“¦ Batch Processing + +Process multiple texts efficiently in a single call: + +```python +# Batch entity extraction +texts = [ + "Google's Sundar Pichai unveiled Gemini AI in Mountain View.", + "Microsoft CEO Satya Nadella announced Copilot at Build 2023.", + "Amazon's Andy Jassy revealed new AWS services in Seattle." +] + +results = extractor.batch_extract_entities( + texts, + ["company", "person", "product", "location"], + batch_size=8 +) +# Returns list of results, one per input text + +# Batch relation extraction +texts = [ + "John works for Microsoft and lives in Seattle.", + "Sarah founded TechStartup in 2020.", + "Bob reports to Alice at Google." 
+] + +results = extractor.batch_extract_relations( + texts, + ["works_for", "founded", "reports_to", "lives_in"], + batch_size=8 +) +# Returns list of relation extraction results for each text +# All requested relation types appear in each result, even if empty + +# Batch with confidence and spans +results = extractor.batch_extract_entities( + texts, + ["company", "person"], + include_confidence=True, + include_spans=True, + batch_size=8 +) +``` + +## πŸŽ“ Training Custom Models + +Train GLiNER2 on your own data to specialize for your domain or use case. + +### Quick Start Training + +```python +from gliner2 import GLiNER2 +from gliner2.training.data import InputExample +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# 1. Prepare training data +examples = [ + InputExample( + text="John works at Google in California.", + entities={"person": ["John"], "company": ["Google"], "location": ["California"]} + ), + InputExample( + text="Apple released iPhone 15.", + entities={"company": ["Apple"], "product": ["iPhone 15"]} + ), + # Add more examples... +] + +# 2. Configure training +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig( + output_dir="./output", + num_epochs=10, + batch_size=8, + encoder_lr=1e-5, + task_lr=5e-4 +) + +# 3. 
Train +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=examples) +``` + +### Training Data Format (JSONL) + +GLiNER2 uses JSONL format where each line contains an `input` and `output` field: + +```jsonl +{"input": "Tim Cook is the CEO of Apple Inc., based in Cupertino, California.", "output": {"entities": {"person": ["Tim Cook"], "company": ["Apple Inc."], "location": ["Cupertino", "California"]}, "entity_descriptions": {"person": "Full name of a person", "company": "Business organization name", "location": "Geographic location or place"}}} +{"input": "OpenAI released GPT-4 in March 2023.", "output": {"entities": {"company": ["OpenAI"], "model": ["GPT-4"], "date": ["March 2023"]}}} +``` + +**Classification Example:** +```jsonl +{"input": "This movie is absolutely fantastic! I loved every minute of it.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}} +{"input": "The service was terrible and the food was cold.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["negative"]}]}} +``` + +**Structured Extraction Example:** +```jsonl +{"input": "iPhone 15 Pro Max with 256GB storage, priced at $1199.", "output": {"json_structures": [{"product": {"name": "iPhone 15 Pro Max", "storage": "256GB", "price": "$1199"}}]}} +``` + +**Relation Extraction Example:** +```jsonl +{"input": "John works for Apple Inc. 
and lives in San Francisco.", "output": {"relations": [{"works_for": {"head": "John", "tail": "Apple Inc."}}, {"lives_in": {"head": "John", "tail": "San Francisco"}}]}} +``` + +### Training from JSONL File + +```python +from gliner2 import GLiNER2 +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# Load model and train from JSONL file +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig(output_dir="./output", num_epochs=10) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data="train.jsonl") # Path to your JSONL file +``` + +### LoRA Training (Parameter-Efficient Fine-Tuning) + +Train lightweight adapters for domain-specific tasks: + +```python +from gliner2 import GLiNER2 +from gliner2.training.data import InputExample +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# Prepare domain-specific data +legal_examples = [ + InputExample( + text="Apple Inc. filed a lawsuit against Samsung Electronics.", + entities={"company": ["Apple Inc.", "Samsung Electronics"]} + ), + # Add more examples... 
+] + +# Configure LoRA training +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig( + output_dir="./legal_adapter", + num_epochs=10, + batch_size=8, + encoder_lr=1e-5, + task_lr=5e-4, + + # LoRA settings + use_lora=True, # Enable LoRA + lora_r=8, # Rank (4, 8, 16, 32) + lora_alpha=16.0, # Scaling factor (usually 2*r) + lora_dropout=0.0, # Dropout for LoRA layers + save_adapter_only=True # Save only adapter (~5MB vs ~450MB) +) + +# Train adapter +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=legal_examples) + +# Use the adapter +model.load_adapter("./legal_adapter/final") +results = model.extract_entities(legal_text, ["company", "law"]) +``` + +**Benefits of LoRA:** +- **Smaller size**: Adapters are ~2-10 MB vs ~450 MB for full models +- **Faster training**: 2-3x faster than full fine-tuning +- **Easy switching**: Swap adapters in milliseconds for different domains + +### Complete Training Example + +```python +from gliner2 import GLiNER2 +from gliner2.training.data import InputExample, TrainingDataset +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# Prepare training data +train_examples = [ + InputExample( + text="Tim Cook is the CEO of Apple Inc., based in Cupertino, California.", + entities={ + "person": ["Tim Cook"], + "company": ["Apple Inc."], + "location": ["Cupertino", "California"] + }, + entity_descriptions={ + "person": "Full name of a person", + "company": "Business organization name", + "location": "Geographic location or place" + } + ), + # Add more examples... 
+] + +# Create and validate dataset +train_dataset = TrainingDataset(train_examples) +train_dataset.validate(strict=True, raise_on_error=True) +train_dataset.print_stats() + +# Split into train/validation +train_data, val_data, _ = train_dataset.split( + train_ratio=0.8, + val_ratio=0.2, + test_ratio=0.0, + shuffle=True, + seed=42 +) + +# Configure training +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig( + output_dir="./ner_model", + experiment_name="ner_training", + num_epochs=15, + batch_size=16, + encoder_lr=1e-5, + task_lr=5e-4, + warmup_ratio=0.1, + scheduler_type="cosine", + fp16=True, + eval_strategy="epoch", + save_best=True, + early_stopping=True, + early_stopping_patience=3 +) + +# Train +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=train_data, val_data=val_data) + +# Load best model +model = GLiNER2.from_pretrained("./ner_model/best") +``` + +For more details, see the [Training Tutorial](tutorial/9-training.md) and [Data Format Guide](tutorial/8-train_data.md). + +## πŸ“„ License + +This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. 
+
+## 📚 Citation
+
+If you use GLiNER2 in your research, please cite:
+
+```bibtex
+@inproceedings{zaratiana-etal-2025-gliner2,
+    title = "{GL}i{NER}2: Schema-Driven Multi-Task Learning for Structured Information Extraction",
+    author = "Zaratiana, Urchade and
+      Pasternak, Gil and
+      Boyd, Oliver and
+      Hurn-Maloney, George and
+      Lewis, Ash",
+    editor = {Habernal, Ivan and
+      Schulam, Peter and
+      Tiedemann, J{\"o}rg},
+    booktitle = "Proceedings of the 2025 Conference on Empirical Methods in Natural Language Processing: System Demonstrations",
+    month = nov,
+    year = "2025",
+    address = "Suzhou, China",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2025.emnlp-demos.10/",
+    pages = "130--140",
+    ISBN = "979-8-89176-334-0",
+    abstract = "Information extraction (IE) is fundamental to numerous NLP applications, yet existing solutions often require specialized models for different tasks or rely on computationally expensive large language models. We present GLiNER2, a unified framework that enhances the original GLiNER architecture to support named entity recognition, text classification, and hierarchical structured data extraction within a single efficient model. Built on a fine-tuned encoder architecture, GLiNER2 maintains CPU efficiency and compact size while introducing multi-task composition through an intuitive schema-based interface. Our experiments demonstrate competitive performance across diverse IE tasks with substantial improvements in deployment accessibility compared to LLM-based alternatives. We release GLiNER2 as an open-source library available through pip, complete with pre-trained models and comprehensive documentation."
+}
+```
+
+## 🙏 Acknowledgments
+
+Built upon the original [GLiNER](https://github.com/urchade/GLiNER) architecture by the team at [Fastino AI](https://fastino.ai).
+
+---
+
+
+ Ready to extract insights from your data?
+ pip install gliner2 +
diff --git a/packages/GLiNER2/RELEASE.md b/packages/GLiNER2/RELEASE.md new file mode 100644 index 0000000..2d642bb --- /dev/null +++ b/packages/GLiNER2/RELEASE.md @@ -0,0 +1,79 @@ +# PyPI Release Guide for GLiNER2 + +## Prerequisites + +- [ ] Python 3.8+ installed +- [ ] PyPI account with API token configured +- [ ] Write access to the repository + +## Release Steps + +### 1. Update Version + +Update version in `gliner2/__init__.py`: +```python +__version__ = "1.0.1" # New version +``` + +### 2. Build Package + +```bash +# Install build tools +pip install build twine + +# Clean previous builds +rm -rf dist/ build/ *.egg-info/ + +# Build package +python -m build +``` + +### 3. Test Build (Optional) + +```bash +# Test on TestPyPI first +twine upload --repository testpypi dist/* + +# Install and test +pip install --index-url https://test.pypi.org/simple/ gliner2 +``` + +### 4. Upload to PyPI + +```bash +# Upload to production PyPI +twine upload dist/* +``` + +### 5. Create GitHub Release + +1. Go to GitHub repository β†’ Releases +2. Click "Create a new release" +3. Tag: `v1.0.1` (matching version) +4. Title: `GLiNER2 v1.0.1` +5. Description: Summary of changes +6. Attach built wheels from `dist/` folder + +### 6. 
Verify Release + +```bash +# Install from PyPI +pip install gliner2==1.0.1 + +# Test basic functionality +python -c "from gliner2 import GLiNER2; print('βœ“ Import successful')" +``` + +## Troubleshooting + +- **Authentication error**: Configure PyPI token in `~/.pypirc` or use `--username __token__` +- **File exists error**: Version already exists on PyPI, increment version number +- **Build fails**: Check `pyproject.toml` dependencies and Python version compatibility + +## Checklist + +- [ ] Version updated in `__init__.py` +- [ ] Package builds without errors +- [ ] Uploaded to PyPI successfully +- [ ] GitHub release created +- [ ] Installation verified \ No newline at end of file diff --git a/packages/GLiNER2/benchmark_statistical.py b/packages/GLiNER2/benchmark_statistical.py new file mode 100644 index 0000000..2599a31 --- /dev/null +++ b/packages/GLiNER2/benchmark_statistical.py @@ -0,0 +1,401 @@ +""" +Statistical benchmark with confidence intervals and p-values. + +Micro-benchmarks: interleaved old/new in same process β†’ paired t-test. +End-to-end: saves raw timings to JSON for cross-process Welch's t-test. 
+ +Usage: + # Baseline + git stash + python benchmark_statistical.py --tag baseline --n 300 + git stash pop + + # Optimized + python benchmark_statistical.py --tag optimized --n 300 + + # Compare + python benchmark_statistical.py --compare baseline optimized +""" + +import argparse +import json +import math +import random +import time +import statistics +import sys +from collections import OrderedDict + +import torch +from scipy import stats as sp_stats + + +# ─── Helpers ────────────────────────────────────────────────────── + +def sync(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def ci95(data): + """95% CI half-width using t-distribution.""" + n = len(data) + if n < 2: + return 0.0 + se = statistics.stdev(data) / math.sqrt(n) + t_crit = sp_stats.t.ppf(0.975, df=n - 1) + return t_crit * se + + +def collect(fn, n_warmup, n_iter): + """Run fn with warmup, return list of times in ms.""" + for _ in range(n_warmup): + fn() + sync() + times = [] + for _ in range(n_iter): + sync() + t0 = time.perf_counter() + fn() + sync() + times.append((time.perf_counter() - t0) * 1000) + return times + + +def paired_test(old_times, new_times): + """Paired t-test on matched samples. Returns (t_stat, p_value, mean_diff, ci95_diff).""" + diffs = [o - n for o, n in zip(old_times, new_times)] + n = len(diffs) + mean_d = statistics.mean(diffs) + se_d = statistics.stdev(diffs) / math.sqrt(n) + t_stat = mean_d / se_d if se_d > 0 else 0 + p_val = 2 * sp_stats.t.sf(abs(t_stat), df=n - 1) + hw = ci95(diffs) + return t_stat, p_val, mean_d, hw + + +def welch_test(a, b): + """Welch's t-test (unequal variance). 
Returns (t_stat, p_value).""" + t_stat, p_val = sp_stats.ttest_ind(a, b, equal_var=False) + return t_stat, p_val + + +def fmt_p(p): + if p < 0.001: + return f"{p:.2e}" + return f"{p:.4f}" + + +# ─── End-to-end benchmark ──────────────────────────────────────── + +def run_e2e(n_iter, n_warmup): + """Run end-to-end scenarios, return dict of {name: [times]}.""" + from gliner2 import GLiNER2 + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + model = model.to(device) + model.eval() + + text1 = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12, 2023." + ents = ["company", "person", "product", "location", "date"] + texts8 = [ + "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino.", + "Google's Sundar Pichai spoke at the conference in Mountain View.", + "Microsoft released Windows 11 in Redmond last year.", + "Amazon founder Jeff Bezos invested in Blue Origin in Seattle.", + "Tesla CEO Elon Musk unveiled the Cybertruck at the Fremont factory.", + "Meta's Mark Zuckerberg presented Quest 3 in Menlo Park.", + "NVIDIA's Jensen Huang showcased the H100 GPU at GTC in San Jose.", + "OpenAI CEO Sam Altman launched GPT-4 in San Francisco.", + ] + long_text = ( + "Apple Inc., headquartered in Cupertino, California, is a multinational technology company " + "founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in April 1976. The company designs, " + "develops, and sells consumer electronics, computer software, and online services. Tim Cook " + "has served as CEO since August 2011. Apple's main products include the iPhone, iPad, Mac, " + "Apple Watch, and AirPods. The company also operates services including the App Store, " + "Apple Music, iCloud, and Apple TV Plus. In 2023, Apple reported annual revenue of $383 " + "billion, making it the world's largest technology company by revenue. The company employs " + "over 160,000 people worldwide." 
+ ) + ents6 = ["company", "person", "product", "location", "date", "monetary_value"] + text_struct = "John Smith, aged 35, is a software engineer at Google in Mountain View." + schema_struct = model.create_schema() + schema_struct.structure("person").field("name").field("age").field("job_title").field("company").field("location") + text_rel = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12." + rels = ["CEO_of", "located_in", "announced_on"] + + results = OrderedDict() + scenarios = [ + ("single_entity", lambda: model.extract_entities(text1, ents)), + ("single_structure", lambda: model.extract(text_struct, schema_struct)), + ("single_relation", lambda: model.extract_relations(text_rel, rels)), + ("batch8_entity", lambda: model.batch_extract_entities(texts8, ents, batch_size=8)), + ("long_text_entity", lambda: model.extract_entities(long_text, ents6)), + ] + + for name, fn in scenarios: + print(f" Running {name} (n={n_iter})...", end=" ", flush=True) + times = collect(fn, n_warmup, n_iter) + results[name] = times + m, hw = statistics.mean(times), ci95(times) + print(f"{m:.2f} Β± {hw:.2f} ms") + + return results + + +# ─── Micro-benchmarks (interleaved old/new) ────────────────────── + +def run_micro(n_iter, n_warmup): + """Run micro-benchmarks with interleaved old/new for paired comparison.""" + import copy + from gliner2 import GLiNER2 + from gliner2.training.trainer import ExtractorCollator + from torch.utils.data import DataLoader + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + model = model.to(device) + model.eval() + tokenizer = model.processor.tokenizer + + results = OrderedDict() + + # --- OPT-1: Token ID lookup --- + special_set_str = {"[P]", "[C]", "[E]", "[R]", "[L]"} + special_ids = frozenset(tokenizer.convert_tokens_to_ids(t) for t in special_set_str) + dummy_ids = list(range(200)) + + def opt1_old(): + for tid in dummy_ids: + tok = 
tokenizer.convert_ids_to_tokens(tid) + _ = tok in special_set_str + + def opt1_new(): + for tid in dummy_ids: + _ = tid in special_ids + + print(" OPT-1 Token ID lookup...", end=" ", flush=True) + old_t, new_t = _interleaved(opt1_old, opt1_new, n_warmup, n_iter) + results["OPT-1 Token ID lookup"] = {"old": old_t, "new": new_t} + _print_paired(old_t, new_t) + + # --- OPT-3: Avoid retokenization --- + test_text = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12." + dummy_map = list(range(15)) + + def opt3_old(): + return len(model.processor._tokenize_text(test_text)) + + def opt3_new(): + return len(dummy_map) + + print(" OPT-3 Avoid retokenization...", end=" ", flush=True) + old_t, new_t = _interleaved(opt3_old, opt3_new, n_warmup, n_iter) + results["OPT-3 Avoid retokenization"] = {"old": old_t, "new": new_t} + _print_paired(old_t, new_t) + + # --- OPT-4: Deepcopy --- + schema_dict = { + "json_structures": [{"person": {"name": "", "age": "", "job": ""}}], + "entities": {"company": "", "location": ""}, + "relations": [], "classifications": [], + } + record = {"text": "Apple CEO Tim Cook announced iPhone 15." 
* 3, "schema": schema_dict} + + def opt4_old(): + return copy.deepcopy(record) + + def opt4_new(): + return {"text": record["text"], "schema": copy.deepcopy(record["schema"])} + + print(" OPT-4 Deepcopy...", end=" ", flush=True) + old_t, new_t = _interleaved(opt4_old, opt4_new, n_warmup, n_iter) + results["OPT-4 Deepcopy"] = {"old": old_t, "new": new_t} + _print_paired(old_t, new_t) + + # --- OPT-6: Token cache --- + special_tokens = ["[SEP_STRUCT]", "[SEP_TEXT]", "[P]", "[C]", "[E]", "[R]", "[L]", + "[EXAMPLE]", "[OUTPUT]", "[DESCRIPTION]", "(", ")", ",", "|"] + cache = {tok: tokenizer.tokenize(tok) for tok in special_tokens} + test_tokens = special_tokens * 10 + + def opt6_old(): + for tok in test_tokens: + tokenizer.tokenize(tok) + + def opt6_new(): + for tok in test_tokens: + if tok in cache: + _ = cache[tok] + else: + tokenizer.tokenize(tok) + + print(" OPT-6 Token cache...", end=" ", flush=True) + old_t, new_t = _interleaved(opt6_old, opt6_new, n_warmup, n_iter) + results["OPT-6 Token cache"] = {"old": old_t, "new": new_t} + _print_paired(old_t, new_t) + + # --- OPT-12: Skip DataLoader --- + collator = ExtractorCollator(model.processor, is_training=False) + text_norm = "Apple CEO Tim Cook announced the iPhone 15 launch in Cupertino on September 12, 2023." 
+ schema_e = model.create_schema().entities(["company", "person", "product", "location", "date"]) + sd = schema_e.build() + for c in sd.get("classifications", []): + c.setdefault("true_label", ["N/A"]) + small_dataset = [(text_norm, sd)] + + def opt12_old(): + loader = DataLoader(small_dataset, batch_size=8, shuffle=False, + num_workers=0, collate_fn=collator) + return list(loader) + + def opt12_new(): + return [collator(small_dataset)] + + print(" OPT-12 Skip DataLoader...", end=" ", flush=True) + old_t, new_t = _interleaved(opt12_old, opt12_new, n_warmup, n_iter) + results["OPT-12 Skip DataLoader"] = {"old": old_t, "new": new_t} + _print_paired(old_t, new_t) + + return results + + +def _interleaved(old_fn, new_fn, n_warmup, n_iter): + """Run old/new interleaved to eliminate ordering effects. Returns paired lists.""" + # Warmup both + for _ in range(n_warmup): + old_fn() + new_fn() + sync() + + old_times = [] + new_times = [] + for _ in range(n_iter): + # Randomize order each iteration to eliminate systematic bias + if random.random() < 0.5: + sync(); t0 = time.perf_counter(); old_fn(); sync() + old_times.append((time.perf_counter() - t0) * 1000) + sync(); t0 = time.perf_counter(); new_fn(); sync() + new_times.append((time.perf_counter() - t0) * 1000) + else: + sync(); t0 = time.perf_counter(); new_fn(); sync() + new_times.append((time.perf_counter() - t0) * 1000) + sync(); t0 = time.perf_counter(); old_fn(); sync() + old_times.append((time.perf_counter() - t0) * 1000) + + return old_times, new_times + + +def _print_paired(old_t, new_t): + m_old, m_new = statistics.mean(old_t), statistics.mean(new_t) + t_stat, p_val, mean_diff, hw = paired_test(old_t, new_t) + speedup = m_old / m_new if m_new > 0 else float('inf') + print(f"{m_old:.4f} -> {m_new:.4f} ms ({speedup:.1f}x) " + f"diff={mean_diff:.4f}Β±{hw:.4f}ms p={fmt_p(p_val)}") + + +# ─── Compare mode ──────────────────────────────────────────────── + +def compare(baseline_path, optimized_path): + """Compare two 
end-to-end result files with Welch's t-test.""" + with open(baseline_path) as f: + baseline = json.load(f) + with open(optimized_path) as f: + optimized = json.load(f) + + print(f"\nBaseline: {baseline_path} (device={baseline['device']}, n={baseline.get('n', '?')})") + print(f"Optimized: {optimized_path} (device={optimized['device']}, n={optimized.get('n', '?')})") + + print(f"\n{'Scenario':<25} {'Baseline':>18} {'Optimized':>18} {'Diff':>14} {'Speedup':>8} {'p-value':>10}") + print("=" * 100) + + for name in baseline["e2e"]: + b = baseline["e2e"][name] + o = optimized["e2e"][name] + + m_b, ci_b = statistics.mean(b), ci95(b) + m_o, ci_o = statistics.mean(o), ci95(o) + diff = m_b - m_o + diff_ci = math.sqrt(ci_b**2 + ci_o**2) # approximate CI of difference + speedup = m_b / m_o if m_o > 0 else float('inf') + t_stat, p_val = welch_test(b, o) + + sig = "*" if p_val < 0.05 else " " + if p_val < 0.01: + sig = "**" + if p_val < 0.001: + sig = "***" + + print(f"{name:<25} {m_b:>7.2f}Β±{ci_b:>5.2f}ms {m_o:>7.2f}Β±{ci_o:>5.2f}ms " + f"{diff:>+6.2f}Β±{diff_ci:>4.2f}ms {speedup:>7.3f}x {fmt_p(p_val):>9}{sig}") + + # Micro-benchmarks (if present in optimized) + if "micro" in optimized: + print(f"\n{'Component':<30} {'Old':>16} {'New':>16} {'Diff (paired)':>18} {'Speedup':>8} {'p-value':>10}") + print("=" * 105) + + for name, data in optimized["micro"].items(): + old_t = data["old"] + new_t = data["new"] + m_old, ci_old = statistics.mean(old_t), ci95(old_t) + m_new, ci_new = statistics.mean(new_t), ci95(new_t) + t_stat, p_val, mean_diff, hw = paired_test(old_t, new_t) + speedup = m_old / m_new if m_new > 0 else float('inf') + + sig = "*" if p_val < 0.05 else " " + if p_val < 0.01: sig = "**" + if p_val < 0.001: sig = "***" + + print(f"{name:<30} {m_old:>6.4f}Β±{ci_old:>6.4f}ms {m_new:>6.4f}Β±{ci_new:>6.4f}ms " + f"{mean_diff:>+7.4f}Β±{hw:>6.4f}ms {speedup:>7.1f}x {fmt_p(p_val):>9}{sig}") + + +# ─── Main ──────────────────────────────────────────────────────── + +def main(): + 
parser = argparse.ArgumentParser() + parser.add_argument("--tag", help="Tag for this run (baseline or optimized)") + parser.add_argument("--n", type=int, default=300, help="Iterations per scenario") + parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") + parser.add_argument("--compare", nargs=2, metavar=("BASELINE", "OPTIMIZED"), + help="Compare two result files") + args = parser.parse_args() + + if args.compare: + compare( + f"bench_stats_{args.compare[0]}.json", + f"bench_stats_{args.compare[1]}.json" + ) + return + + if not args.tag: + parser.error("--tag is required (or use --compare)") + + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"Iterations: {args.n}, Warmup: {args.warmup}\n") + + output = {"tag": args.tag, "device": device, "n": args.n} + + # End-to-end + print("END-TO-END BENCHMARKS") + print("-" * 60) + e2e = run_e2e(args.n, args.warmup) + output["e2e"] = e2e + + # Micro-benchmarks (only meaningful for optimized run since we inline both versions) + print("\nCOMPONENT MICRO-BENCHMARKS (interleaved old/new)") + print("-" * 60) + micro = run_micro(args.n, args.warmup) + output["micro"] = {k: v for k, v in micro.items()} + + out_path = f"bench_stats_{args.tag}.json" + with open(out_path, "w") as f: + json.dump(output, f) + print(f"\nRaw timings saved to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/packages/GLiNER2/gliner2/__init__.py b/packages/GLiNER2/gliner2/__init__.py new file mode 100644 index 0000000..7364220 --- /dev/null +++ b/packages/GLiNER2/gliner2/__init__.py @@ -0,0 +1,23 @@ +__version__ = "1.2.4" + +from .inference.engine import GLiNER2, RegexValidator +from .model import Extractor, ExtractorConfig +from .api_client import ( + GLiNER2API, + GLiNER2APIError, + AuthenticationError, + ValidationError, + ServerError, +) +from .training.lora import ( + LoRAConfig, + LoRAAdapterConfig, + LoRALayer, + load_lora_adapter, + save_lora_adapter, + 
unload_lora_adapter, + has_lora_adapter, + apply_lora_to_model, + merge_lora_weights, + unmerge_lora_weights, +) \ No newline at end of file diff --git a/packages/GLiNER2/gliner2/api_client.py b/packages/GLiNER2/gliner2/api_client.py new file mode 100644 index 0000000..5cc8306 --- /dev/null +++ b/packages/GLiNER2/gliner2/api_client.py @@ -0,0 +1,989 @@ +""" +GLiNER2 API Client + +This module provides an API-based wrapper for GLiNER2 that mirrors the local +model interface. It allows seamless switching between local and API-based +inference. + +Usage: + >>> from gliner2 import GLiNER2 + >>> + >>> # Load from API (uses environment variable for API key) + >>> extractor = GLiNER2.from_api() + >>> + >>> # Use exactly like local model + >>> results = extractor.extract_entities( + ... "Apple released iPhone 15 in September 2023.", + ... ["company", "product", "date"] + ... ) +""" + +from __future__ import annotations + +import os +import logging +import warnings +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Union, Literal +from urllib.parse import urljoin +from urllib3.util import Retry +import requests +from requests.adapters import HTTPAdapter + +logger = logging.getLogger(__name__) + + +class GLiNER2APIError(Exception): + """Base exception for GLiNER2 API errors.""" + + def __init__(self, message: str, status_code: Optional[int] = None, response_data: Optional[Dict] = None): + super().__init__(message) + self.status_code = status_code + self.response_data = response_data + + +class AuthenticationError(GLiNER2APIError): + """Raised when API key is invalid or expired.""" + pass + + +class ValidationError(GLiNER2APIError): + """Raised when request data is invalid.""" + pass + + +class ServerError(GLiNER2APIError): + """Raised when server encounters an error.""" + pass + + +class StructureBuilderAPI: + """ + Builder for structured data schemas for API-based extraction. 
class StructureBuilderAPI:
    """Fluent helper that accumulates field definitions for one structure.

    Mirrors the interface of StructureBuilder from the local model. The
    collected fields are serialized into the owning SchemaAPI the first
    time any schema method is accessed through this builder (see
    ``__getattr__``), or when the schema is built.
    """

    def __init__(self, schema: 'SchemaAPI', parent: str):
        self.schema = schema
        self.parent = parent
        self.fields = OrderedDict()
        self.field_order = []
        self._finished = False

    def field(
        self,
        name: str,
        dtype: Literal["str", "list"] = "list",
        choices: Optional[List[str]] = None,
        description: Optional[str] = None,
        threshold: Optional[float] = None,
        validators: Optional[List] = None
    ) -> 'StructureBuilderAPI':
        """Register one field on the structure; returns self for chaining."""
        # Validators only exist in the local model; warn rather than fail.
        if validators:
            warnings.warn(
                f"Field '{name}': RegexValidator is not supported in API mode. "
                "Validators will be ignored. Use local model for regex-based filtering.",
                UserWarning,
                stacklevel=2
            )

        self.field_order.append(name)
        self.fields[name] = {
            "dtype": dtype,
            "choices": choices,
            "description": description,
            "threshold": threshold,
        }
        return self

    def _auto_finish(self):
        """Serialize accumulated fields into the owning schema (idempotent)."""
        if self._finished:
            return

        specs = []
        for fname in self.field_order:
            cfg = self.fields[fname]
            # Advanced options force the dict format; plain fields keep the
            # compact "name::type::description" string for compatibility.
            has_threshold = cfg.get('threshold') is not None
            has_choices = cfg.get('choices') is not None
            if has_threshold or has_choices:
                entry = {"name": fname, "dtype": cfg['dtype']}
                if cfg.get('description'):
                    entry["description"] = cfg['description']
                if has_threshold:
                    entry["threshold"] = cfg['threshold']
                if has_choices:
                    entry["choices"] = cfg['choices']
                specs.append(entry)
            else:
                text = f"{fname}::{cfg['dtype']}"
                if cfg.get('description'):
                    text += f"::{cfg['description']}"
                specs.append(text)

        self.schema._structures[self.parent] = specs
        self._finished = True

    def __getattr__(self, name):
        # Forward unknown attributes to the schema, flushing first so the
        # structure is never lost when the caller chains schema methods.
        if hasattr(self.schema, name):
            self._auto_finish()
            return getattr(self.schema, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


class SchemaAPI:
    """Collects extraction tasks and renders them into the API schema dict.

    Supports entities, classifications, structured data, and relations;
    all methods return self for fluent chaining and ``build()`` produces
    the final request payload.
    """

    def __init__(self):
        self._entities = None
        self._entity_dtype = "list"
        self._entity_threshold = None
        self._classifications = {}
        self._structures = {}
        self._relations = None
        self._relation_threshold = None
        self._active_structure_builder = None

    def _flush_active_structure(self):
        """Finalize any in-progress structure builder before a new task."""
        if self._active_structure_builder:
            self._active_structure_builder._auto_finish()
            self._active_structure_builder = None

    def entities(
        self,
        entity_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
        dtype: Literal["str", "list"] = "list",
        threshold: Optional[float] = None
    ) -> 'SchemaAPI':
        """Add an entity extraction task (str is promoted to a one-item list)."""
        self._flush_active_structure()
        if isinstance(entity_types, str):
            self._entities = [entity_types]
        elif isinstance(entity_types, (list, dict)):
            self._entities = entity_types
        self._entity_dtype = dtype
        self._entity_threshold = threshold
        return self

    def classification(
        self,
        task: str,
        labels: Union[List[str], Dict[str, str]],
        multi_label: bool = False,
        cls_threshold: float = 0.5,
        **kwargs
    ) -> 'SchemaAPI':
        """Add a text classification task; dict labels contribute their keys."""
        self._flush_active_structure()
        label_names = list(labels.keys()) if isinstance(labels, dict) else labels
        self._classifications[task] = {
            "labels": label_names,
            "multi_label": multi_label,
            "cls_threshold": cls_threshold,
        }
        return self

    def structure(self, name: str) -> StructureBuilderAPI:
        """Start building a structured data schema under the given name."""
        self._flush_active_structure()
        builder = StructureBuilderAPI(self, name)
        self._active_structure_builder = builder
        return builder

    def relations(
        self,
        relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
        threshold: Optional[float] = None
    ) -> 'SchemaAPI':
        """Add a relation extraction task.

        Args:
            relation_types: Single type (str), list of types, or a dict
                mapping types to descriptions / full configuration.
            threshold: Default confidence threshold for relations.

        Returns:
            Self for method chaining.
        """
        self._flush_active_structure()
        if isinstance(relation_types, str):
            self._relations = [relation_types]
        elif isinstance(relation_types, (list, dict)):
            self._relations = relation_types
        self._relation_threshold = threshold
        return self

    def build(self) -> Dict[str, Any]:
        """Render all registered tasks into the API request schema."""
        self._flush_active_structure()

        schema = {}
        if self._entities is not None:
            schema["entities"] = self._entities
            schema["entity_dtype"] = self._entity_dtype
            if self._entity_threshold is not None:
                schema["entity_threshold"] = self._entity_threshold
        if self._classifications:
            schema["classifications"] = self._classifications
        if self._structures:
            schema["structures"] = self._structures
        if self._relations is not None:
            schema["relations"] = self._relations
            if self._relation_threshold is not None:
                schema["relation_threshold"] = self._relation_threshold
        return schema
class GLiNER2API:
    """
    API-based GLiNER2 client that mirrors the local model interface.

    This class provides the same methods as GLiNER2 but makes HTTP requests
    to the API endpoint instead of running local inference.

    Attributes:
        api_key: API authentication key
        base_url: API base URL
        timeout: Request timeout in seconds
        max_retries: Maximum number of retries for failed requests
    """

    DEFAULT_BASE_URL = "https://api.fastino.ai"

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base_url: Optional[str] = None,
        timeout: float = 30.0,
        max_retries: int = 3,
    ):
        """
        Initialize the GLiNER2 API client.

        Args:
            api_key: API authentication key. If not provided, reads from the
                PIONEER_API_KEY environment variable.
            api_base_url: Override the default API base URL (also settable via
                the GLINER2_API_BASE_URL environment variable).
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retries for failed requests.

        Raises:
            ValueError: If no API key is provided and PIONEER_API_KEY is not set.
        """
        if api_key is None:
            api_key = os.environ.get("PIONEER_API_KEY")
        if api_key is None:
            raise ValueError(
                "API key must be provided either as an argument or via "
                "PIONEER_API_KEY environment variable"
            )

        self.api_key = api_key
        self.base_url = api_base_url or os.environ.get(
            "GLINER2_API_BASE_URL", self.DEFAULT_BASE_URL
        )
        self.timeout = timeout
        self.max_retries = max_retries

        # Shared session with automatic retries on transient failures.
        self.session = requests.Session()
        self.session.headers.update({
            "X-API-Key": api_key,
            "Content-Type": "application/json",
        })
        retry_strategy = Retry(
            total=max_retries,
            backoff_factor=1,  # 1s, 2s, 4s backoff
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["POST"],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

        logger.debug(f"Initialized GLiNER2API for {self.base_url}")

    @staticmethod
    def _safe_json(response) -> Optional[Dict]:
        """Decode a response body as JSON, or None when empty or malformed.

        Gateways and proxies often return HTML error pages; without this
        guard, ``response.json()`` would raise a JSONDecodeError before the
        intended GLiNER2APIError could be constructed.
        """
        if not response.content:
            return None
        try:
            return response.json()
        except ValueError:  # covers requests.exceptions.JSONDecodeError
            return None

    def _raise_for_status(self, response) -> None:
        """Map HTTP error responses onto the GLiNER2 exception hierarchy.

        Returns silently for 2xx/3xx responses.

        Raises:
            AuthenticationError: On 401 (no status_code attached, matching
                the historical contract).
            ValidationError: On 400/422.
            ServerError: On any 5xx.
            GLiNER2APIError: On any other non-OK status.
        """
        if response.ok:
            return

        error_data = self._safe_json(response)

        def detail(default: str) -> str:
            # Prefer the server-provided "detail" message when present.
            return error_data.get("detail", default) if error_data else default

        status = response.status_code
        if status == 401:
            raise AuthenticationError(
                detail("Invalid or expired API key"),
                response_data=error_data,
            )
        if status in (400, 422):
            raise ValidationError(
                detail("Request validation failed"),
                status_code=status,
                response_data=error_data,
            )
        if status >= 500:
            raise ServerError(
                detail("Server error occurred"),
                status_code=status,
                response_data=error_data,
            )
        raise GLiNER2APIError(
            detail(f"Request failed with status {status}"),
            status_code=status,
            response_data=error_data,
        )

    def _make_request(
        self,
        task: str,
        text: Union[str, List[str]],
        schema: Union[List[str], Dict],
        threshold: float = 0.5,
        include_confidence: bool = False,
        include_spans: bool = False,
        format_results: bool = True,
    ) -> Dict[str, Any]:
        """
        Make an HTTP request to the GLiNER-2 API.

        Args:
            task: Task type (extract_entities, classify_text, extract_json, schema)
            text: Text to process (string or list for batch)
            schema: Schema for extraction
            threshold: Confidence threshold
            include_confidence: Whether to include confidence scores in results
            include_spans: Whether to include character-level start/end positions
            format_results: Whether to format results (False for raw extraction data)

        Returns:
            API response result (the "result" key when present, else the body)

        Raises:
            GLiNER2APIError: If the request fails (subclasses for auth /
                validation / server errors).
        """
        # Ensure base_url ends with / so urljoin keeps any path prefix.
        base = self.base_url.rstrip('/') + '/'
        url = urljoin(base, "gliner-2")

        payload = {
            "task": task,
            "text": text,
            "schema": schema,
            "threshold": threshold,
            "include_confidence": include_confidence,
            "include_spans": include_spans,
            "format_results": format_results,
        }

        logger.debug("Making POST request to %s", url)

        try:
            response = self.session.post(
                url,
                json=payload,
                timeout=self.timeout,
            )
            logger.debug("Response status: %s", response.status_code)
            self._raise_for_status(response)
            data = response.json()
            return data.get("result", data)
        except requests.exceptions.Timeout:
            raise GLiNER2APIError(f"Request timed out after {self.timeout}s")
        except requests.exceptions.ConnectionError as e:
            raise GLiNER2APIError(f"Connection error: {str(e)}")
        except requests.exceptions.RequestException as e:
            raise GLiNER2APIError(f"Request failed: {str(e)}")

    def create_schema(self) -> SchemaAPI:
        """Create a new schema for defining extraction tasks."""
        return SchemaAPI()
+ If format_results=False, returns raw extraction data with positions. + """ + # Normalize entity types to list + if isinstance(entity_types, dict): + entities = list(entity_types.keys()) + else: + entities = entity_types + + result = self._make_request( + task="extract_entities", + text=text, + schema=entities, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + # Wrap result in expected format if needed (only for formatted results) + if format_results and isinstance(result, dict) and "entities" not in result: + return {"entities": result} + return result + + def batch_extract_entities( + self, + texts: List[str], + entity_types: Union[List[str], Dict[str, Union[str, Dict]]], + batch_size: int = 8, + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> List[Dict[str, Any]]: + """ + Batch extract entities from multiple texts. + + Args: + texts: List of input texts. + entity_types: List of entity types or dict with descriptions. + batch_size: Batch size (used by API for optimization). + threshold: Minimum confidence threshold. + format_results: Whether to format results. If False, returns raw extraction data. + include_confidence: Whether to include confidence scores. + include_spans: Whether to include character-level start/end positions. + + Returns: + List of dictionaries with "entities" key. + If include_confidence=True, entity values include confidence scores. + If include_spans=True, entity values include start/end positions. + If format_results=False, returns raw extraction data with positions. 
+ """ + # Normalize entity types to list + if isinstance(entity_types, dict): + entities = list(entity_types.keys()) + else: + entities = entity_types + + result = self._make_request( + task="extract_entities", + text=texts, + schema=entities, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + # Ensure result is a list + if isinstance(result, dict): + return [result] + return result + + # ------------------------------------------------------------------------- + # Text Classification Methods + # ------------------------------------------------------------------------- + + def classify_text( + self, + text: str, + tasks: Dict[str, Union[List[str], Dict[str, Any]]], + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> Dict[str, Any]: + """ + Classify text into categories. + + Args: + text: Text to classify. + tasks: Classification tasks where keys are task names. + threshold: Confidence threshold. + format_results: Whether to format results. If False, returns raw extraction data. + include_confidence: Whether to include confidence scores. + include_spans: Whether to include character-level start/end positions. + + Returns: + Classification results keyed by task name. + If include_confidence=True, results include confidence scores. + If format_results=False, returns raw extraction data. 
+ """ + # Convert tasks to API format + # For classify_text task, schema should be {"categories": [...]} + # But for multi-task, we need to use the schema task + if len(tasks) == 1: + # Single task - use classify_text endpoint + task_name = list(tasks.keys())[0] + task_config = tasks[task_name] + + if isinstance(task_config, dict) and "labels" in task_config: + categories = task_config["labels"] + else: + categories = task_config + + result = self._make_request( + task="classify_text", + text=text, + schema={"categories": categories}, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + # Wrap result with task name (only for formatted results) + if format_results and isinstance(result, dict) and task_name not in result: + return {task_name: result.get("classification", result)} + return result + else: + # Multiple tasks - use schema endpoint + schema = {"classifications": tasks} + result = self._make_request( + task="schema", + text=text, + schema=schema, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + return result + + def batch_classify_text( + self, + texts: List[str], + tasks: Dict[str, Union[List[str], Dict[str, Any]]], + batch_size: int = 8, + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> List[Dict[str, Any]]: + """ + Batch classify multiple texts. + + Args: + texts: List of texts to classify. + tasks: Classification tasks. + batch_size: Batch size. + threshold: Confidence threshold. + format_results: Whether to format results. If False, returns raw extraction data. + include_confidence: Whether to include confidence scores. + include_spans: Whether to include character-level start/end positions. + + Returns: + List of classification results. + If include_confidence=True, results include confidence scores. 
+ If format_results=False, returns raw extraction data. + """ + # Use schema task for batch classification + schema = {"classifications": tasks} + result = self._make_request( + task="schema", + text=texts, + schema=schema, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + if isinstance(result, dict): + return [result] + return result + + # ------------------------------------------------------------------------- + # JSON Extraction Methods + # ------------------------------------------------------------------------- + + def extract_json( + self, + text: str, + structures: Dict[str, List[str]], + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> Dict[str, Any]: + """ + Extract structured data from text. + + Args: + text: Text to extract data from. + structures: Structure definitions with field specs. + threshold: Minimum confidence threshold. + format_results: Whether to format results. If False, returns raw extraction data. + include_confidence: Whether to include confidence scores. + include_spans: Whether to include character-level start/end positions. + + Returns: + Extracted structures keyed by structure name. + If include_confidence=True, field values include confidence scores. + If include_spans=True, field values include start/end positions. + If format_results=False, returns raw extraction data with positions. 
+ """ + result = self._make_request( + task="extract_json", + text=text, + schema=structures, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + return result + + def batch_extract_json( + self, + texts: List[str], + structures: Dict[str, List[str]], + batch_size: int = 8, + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> List[Dict[str, Any]]: + """ + Batch extract structured data from multiple texts. + + Args: + texts: List of texts. + structures: Structure definitions. + batch_size: Batch size. + threshold: Confidence threshold. + format_results: Whether to format results. If False, returns raw extraction data. + include_confidence: Whether to include confidence scores. + include_spans: Whether to include character-level start/end positions. + + Returns: + List of extracted structures. + If include_confidence=True, field values include confidence scores. + If include_spans=True, field values include start/end positions. + If format_results=False, returns raw extraction data with positions. + """ + result = self._make_request( + task="extract_json", + text=texts, + schema=structures, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + if isinstance(result, dict): + return [result] + return result + + # ------------------------------------------------------------------------- + # Relation Extraction Methods + # ------------------------------------------------------------------------- + + def extract_relations( + self, + text: str, + relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]], + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> Dict[str, Any]: + """ + Extract relations between entities from text. 
+ + Args: + text: Input text to extract relations from. + relation_types: Relation types to extract. Can be: + - str: Single relation type + - List[str]: Multiple relation types + - Dict[str, str]: Relation types with descriptions + - Dict[str, Dict]: Relation types with full configuration + threshold: Minimum confidence threshold. + format_results: Whether to format results. If False, returns raw extraction data. + include_confidence: Whether to include confidence scores in results. + include_spans: Whether to include character-level start/end positions. + + Returns: + Dictionary with "relation_extraction" key containing extracted relations. + Relations are grouped by type with tuples (source, target). + Format: {"relation_extraction": {"relation_name": [("source", "target"), ...]}} + """ + # Build schema with relations + schema = self.create_schema().relations(relation_types).build() + + result = self._make_request( + task="schema", + text=text, + schema=schema, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + return result + + def batch_extract_relations( + self, + texts: List[str], + relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]], + batch_size: int = 8, + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> List[Dict[str, Any]]: + """ + Batch extract relations from multiple texts. + + Args: + texts: List of input texts. + relation_types: Relation types to extract. + batch_size: Batch size (used by API for optimization). + threshold: Minimum confidence threshold. + format_results: Whether to format results. + include_confidence: Whether to include confidence scores. + include_spans: Whether to include character-level start/end positions. + + Returns: + List of dictionaries with "relation_extraction" key. 
+ Format: [{"relation_extraction": {"relation_name": [("source", "target"), ...]}}] + """ + # Build schema with relations + schema = self.create_schema().relations(relation_types).build() + + result = self._make_request( + task="schema", + text=texts, + schema=schema, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + # Ensure result is a list + if isinstance(result, dict): + return [result] + return result + + # ------------------------------------------------------------------------- + # General Extraction Methods + # ------------------------------------------------------------------------- + + def extract( + self, + text: str, + schema: Union[SchemaAPI, Dict[str, Any]], + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> Dict[str, Any]: + """ + Extract information from text using a schema. + + Args: + text: Input text to extract from. + schema: Schema defining what to extract. + threshold: Minimum confidence threshold. + format_results: Whether to format results. If False, returns raw extraction data. + include_confidence: Whether to include confidence scores. + include_spans: Whether to include character-level start/end positions. + + Returns: + Extraction results organized by task name. + If include_confidence=True, values include confidence scores. + If include_spans=True, values include start/end positions. + If format_results=False, returns raw extraction data with positions. 
+ """ + # Build schema dict if needed + if isinstance(schema, SchemaAPI): + schema_dict = schema.build() + elif hasattr(schema, 'build'): + schema_dict = schema.build() + else: + schema_dict = schema + + # Validate schema has at least one extraction task + has_any_task = any( + key in schema_dict + for key in ["entities", "classifications", "structures", "relations"] + ) + if not has_any_task: + raise ValueError("Schema must contain at least one extraction task") + + # Always use schema task to preserve all metadata (thresholds, dtypes, etc.) + return self._make_request( + task="schema", + text=text, + schema=schema_dict, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + def batch_extract( + self, + texts: List[str], + schemas: Union[SchemaAPI, List[SchemaAPI], Dict[str, Any], List[Dict[str, Any]]], + batch_size: int = 8, + threshold: float = 0.5, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> List[Dict[str, Any]]: + """ + Extract information from multiple texts. + + Args: + texts: List of input texts. + schemas: Single schema for all texts or list of schemas. + batch_size: Batch size. + threshold: Confidence threshold. + format_results: Whether to format results. If False, returns raw extraction data. + include_confidence: Whether to include confidence scores. + include_spans: Whether to include character-level start/end positions. + + Returns: + List of extraction results. + If include_confidence=True, values include confidence scores. + If include_spans=True, values include start/end positions. + If format_results=False, returns raw extraction data with positions. 
+ """ + if not texts: + return [] + + # Handle schema variations + if isinstance(schemas, list): + if len(schemas) != len(texts): + raise ValueError( + f"Number of schemas ({len(schemas)}) must match number of texts ({len(texts)})" + ) + # Warn user about multi-schema batch limitation + warnings.warn( + "Multi-schema batch (different schemas per text) is not natively supported by the API. " + "Each text will be processed individually, which may be slower than single-schema batch. " + "For better performance, use the same schema for all texts.", + UserWarning, + stacklevel=2 + ) + # Process each text with its schema individually + results = [] + for text, schema in zip(texts, schemas): + results.append(self.extract(text, schema, threshold, include_confidence=include_confidence, include_spans=include_spans, format_results=format_results)) + return results + + # Single schema for all texts + if isinstance(schemas, SchemaAPI): + schema_dict = schemas.build() + elif hasattr(schemas, 'build'): + schema_dict = schemas.build() + else: + schema_dict = schemas + + return self._make_request( + task="schema", + text=texts, + schema=schema_dict, + threshold=threshold, + include_confidence=include_confidence, + include_spans=include_spans, + format_results=format_results, + ) + + # ------------------------------------------------------------------------- + # Utility Methods + # ------------------------------------------------------------------------- + + def close(self): + """Close the HTTP session.""" + self.session.close() + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.close() + diff --git a/packages/GLiNER2/gliner2/inference/__init__.py b/packages/GLiNER2/gliner2/inference/__init__.py new file mode 100644 index 0000000..1237600 --- /dev/null +++ b/packages/GLiNER2/gliner2/inference/__init__.py @@ -0,0 +1 @@ +from .engine import RegexValidator, GLiNER2 \ No newline at 
@dataclass
class RegexValidator:
    """Regex-based span filter applied during post-processing.

    Attributes:
        pattern: Pattern string or a precompiled regex.
        mode: "full" requires the entire span to match; "partial" accepts
            any substring match.
        exclude: When True, the decision is inverted (matching spans are
            rejected instead of accepted).
        flags: re flags used when compiling a string pattern.
    """
    pattern: str | Pattern[str]
    mode: Literal["full", "partial"] = "full"
    exclude: bool = False
    flags: int = re.IGNORECASE
    _compiled: Pattern[str] = field(init=False, repr=False)

    def __post_init__(self):
        if self.mode not in {"full", "partial"}:
            raise ValueError(f"mode must be 'full' or 'partial', got {self.mode!r}")
        if isinstance(self.pattern, re.Pattern):
            compiled = self.pattern
        else:
            try:
                compiled = re.compile(self.pattern, self.flags)
            except re.error as err:
                raise ValueError(f"Invalid regex: {self.pattern!r}") from err
        # object.__setattr__ keeps this compatible with frozen dataclasses.
        object.__setattr__(self, "_compiled", compiled)

    def __call__(self, text: str) -> bool:
        return self.validate(text)

    def validate(self, text: str) -> bool:
        """Return True when *text* passes the filter (respecting exclude)."""
        if self.mode == "full":
            hit = self._compiled.fullmatch(text)
        else:
            hit = self._compiled.search(text)
        found = hit is not None
        return (not found) if self.exclude else found
be 'full' or 'partial', got {self.mode!r}") + try: + compiled = ( + self.pattern if isinstance(self.pattern, re.Pattern) + else re.compile(self.pattern, self.flags) + ) + except re.error as err: + raise ValueError(f"Invalid regex: {self.pattern!r}") from err + object.__setattr__(self, "_compiled", compiled) + + def __call__(self, text: str) -> bool: + return self.validate(text) + + def validate(self, text: str) -> bool: + matcher = self._compiled.fullmatch if self.mode == "full" else self._compiled.search + matched = matcher(text) is not None + return not matched if self.exclude else matched + + +# ============================================================================= +# Schema Builder +# ============================================================================= + +class StructureBuilder: + """Builder for structured data schemas.""" + + def __init__(self, schema: 'Schema', parent: str): + self.schema = schema + self.parent = parent + self.fields = OrderedDict() + self.descriptions = OrderedDict() + self.field_order = [] + self._finished = False + + def field( + self, + name: str, + dtype: Literal["str", "list"] = "list", + choices: Optional[List[str]] = None, + description: Optional[str] = None, + threshold: Optional[float] = None, + validators: Optional[List[RegexValidator]] = None + ) -> 'StructureBuilder': + """Add a field to the structure.""" + self.fields[name] = {"value": "", "choices": choices} if choices else "" + self.field_order.append(name) + + if description: + self.descriptions[name] = description + + self.schema._store_field_metadata(self.parent, name, dtype, threshold, choices, validators) + return self + + def _auto_finish(self): + if not self._finished: + self.schema._store_field_order(self.parent, self.field_order) + self.schema.schema["json_structures"].append({self.parent: self.fields}) + + if self.descriptions: + if "json_descriptions" not in self.schema.schema: + self.schema.schema["json_descriptions"] = {} + 
                self.schema.schema["json_descriptions"][self.parent] = self.descriptions

            self._finished = True

    def __getattr__(self, name):
        # Delegate unknown attributes to the owning Schema so fluent chains
        # like .structure(...).field(...).entities(...) work; flushing first
        # ensures this structure is committed before the next task is added.
        if hasattr(self.schema, name):
            self._auto_finish()
            return getattr(self.schema, name)
        raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")


class Schema:
    """Schema builder for extraction tasks.

    Collects entity, relation, classification and structure definitions
    plus per-item metadata (thresholds, dtypes, validators, declaration
    order) and produces the final schema dict via ``build()``.
    """

    def __init__(self):
        self.schema = {
            "json_structures": [],
            "classifications": [],
            "entities": OrderedDict(),
            "relations": [],
            "json_descriptions": {},
            "entity_descriptions": OrderedDict()
        }
        self._field_metadata = {}     # "parent.field" -> dtype/threshold/choices/validators
        self._entity_metadata = {}    # entity name -> dtype/threshold
        self._relation_metadata = {}  # relation name -> threshold
        self._field_orders = {}       # parent name -> declared field order
        self._entity_order = []       # entity declaration order
        self._relation_order = []     # relation declaration order
        self._active_builder = None   # StructureBuilder awaiting auto-finish

    def _store_field_metadata(self, parent, field, dtype, threshold, choices, validators=None):
        # Validate the threshold range early so bad schemas fail at build time.
        if threshold is not None and not 0 <= threshold <= 1:
            raise ValueError(f"Threshold must be 0-1, got {threshold}")
        self._field_metadata[f"{parent}.{field}"] = {
            "dtype": dtype, "threshold": threshold, "choices": choices,
            "validators": validators or []
        }

    def _store_entity_metadata(self, entity, dtype, threshold):
        if threshold is not None and not 0 <= threshold <= 1:
            raise ValueError(f"Threshold must be 0-1, got {threshold}")
        self._entity_metadata[entity] = {"dtype": dtype, "threshold": threshold}

    def _store_field_order(self, parent, order):
        self._field_orders[parent] = order

    def structure(self, name: str) -> StructureBuilder:
        """Start building a structure schema."""
        # Commit any structure still being built before starting a new one.
        if self._active_builder:
            self._active_builder._auto_finish()
        self._active_builder = StructureBuilder(self, name)
        return self._active_builder

    def classification(
        self,
        task: str,
        labels: Union[List[str], Dict[str, str]],
        multi_label: bool = False,
        cls_threshold: float = 0.5,
        **kwargs
    ) -> 'Schema':
        """Add classification task.

        Args:
            task: Task name used as the result key.
            labels: Label list, or mapping of label -> description.
            multi_label: Allow several labels above ``cls_threshold``.
            cls_threshold: Decision threshold for (multi-label) selection.
            **kwargs: Extra keys merged verbatim into the task config.
        """
        if self._active_builder:
            self._active_builder._auto_finish()
            self._active_builder = None

        label_names = list(labels.keys()) if isinstance(labels, dict) else labels
        label_descs = labels if isinstance(labels, dict) else None

        config = {
            "task": task, "labels": label_names,
            "multi_label": multi_label, "cls_threshold": cls_threshold,
            "true_label": ["N/A"], **kwargs
        }
        if label_descs:
            config["label_descriptions"] = label_descs

        self.schema["classifications"].append(config)
        return self

    def entities(
        self,
        entity_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
        dtype: Literal["str", "list"] = "list",
        threshold: Optional[float] = None
    ) -> 'Schema':
        """Add entity extraction task.

        Args:
            entity_types: A single name, a list of names, or a mapping of
                name -> description string / per-entity config dict.
            dtype: Default dtype for entities without a per-entity override.
            threshold: Default confidence threshold (0-1) for entities
                without a per-entity override.
        """
        if self._active_builder:
            self._active_builder._auto_finish()
            self._active_builder = None

        entities = self._parse_entity_input(entity_types)

        for name, config in entities.items():
            self.schema["entities"][name] = ""
            if name not in self._entity_order:
                self._entity_order.append(name)

            # Per-entity config overrides the method-level defaults.
            self._store_entity_metadata(
                name,
                config.get("dtype", dtype),
                config.get("threshold", threshold)
            )

            if "description" in config:
                self.schema["entity_descriptions"][name] = config["description"]

        return self

    def _parse_entity_input(self, entity_types):
        # Normalize the three accepted shapes into {name: config_dict}.
        if isinstance(entity_types, str):
            return {entity_types: {}}
        elif isinstance(entity_types, list):
            return {name: {} for name in entity_types}
        elif isinstance(entity_types, dict):
            result = {}
            for name, config in entity_types.items():
                if isinstance(config, str):
                    # A bare string is shorthand for a description.
                    result[name] = {"description": config}
                elif isinstance(config, dict):
                    result[name] = config
                else:
                    result[name] = {}
            return result
        # Reached only for types other than str/list/dict.
        raise ValueError("Invalid entity_types format")

    def relations(
        self,
        relation_types: Union[str, List[str], Dict[str, Union[str, Dict]]],
        threshold: Optional[float] = None
    ) -> 'Schema':
        """Add relation extraction task."""
        if self._active_builder:
            self._active_builder._auto_finish()
            self._active_builder = None

        if isinstance(relation_types, str):
            relations = {relation_types: {}}
        elif isinstance(relation_types, list):
            relations = {name: {} for name in relation_types}
        elif isinstance(relation_types, dict):
            relations = {}
            for name, config in relation_types.items():
                # A bare string value is shorthand for {"description": ...}.
                relations[name] = {"description": config} if isinstance(config, str) else (config if isinstance(config, dict) else {})
        else:
            raise ValueError("Invalid relation_types format")

        for name, config in relations.items():
            # A relation is modeled as a two-field structure (head/tail).
            self.schema["relations"].append({name: {"head": "", "tail": ""}})
            if name not in self._relation_order:
                self._relation_order.append(name)
                self._field_orders[name] = ["head", "tail"]

            rel_threshold = config.get("threshold", threshold)
            if rel_threshold is not None and not 0 <= rel_threshold <= 1:
                raise ValueError(f"Threshold must be 0-1, got {rel_threshold}")
            self._relation_metadata[name] = {"threshold": rel_threshold}

        return self

    def build(self) -> Dict[str, Any]:
        """Build final schema dictionary."""
        # Flush any structure builder that was never explicitly finished.
        if self._active_builder:
            self._active_builder._auto_finish()
            self._active_builder = None
        return self.schema

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Schema':
        """Create a Schema from a dictionary.

        Args:
            data: Dictionary with optional keys: entities, structures,
                classifications, relations

        Returns:
            Schema: Constructed schema instance

        Raises:
            ValidationError: If the input data is invalid

        Example:
            >>> schema_dict = {
            ...     "entities": ["company", "person"],
            ...     "structures": {
            ...         "product_info": {
            ...             "fields": [
            ...                 {"name": "company", "dtype": "str"},
            ...                 {"name": "product"}
            ...             ]
            ...         }
            ...     },
            ...     "classifications": [
            ...         {"task": "sentiment", "labels": ["positive", "negative"]}
            ...     ],
            ...     "relations": ["works_for", "founded_by"]
            ...
}
            >>> schema = Schema.from_dict(schema_dict)
        """
        from gliner2.inference.schema_model import SchemaInput

        # Validate input against the pydantic-style SchemaInput model first;
        # invalid payloads raise before any schema state is built.
        validated = SchemaInput(**data)

        # Build schema using the fluent builder API
        schema = cls()

        # Add entities
        if validated.entities is not None:
            schema.entities(validated.entities)

        # Add structures
        if validated.structures is not None:
            for struct_name, struct_input in validated.structures.items():
                builder = schema.structure(struct_name)
                for field_input in struct_input.fields:
                    builder.field(
                        name=field_input.name,
                        dtype=field_input.dtype,
                        choices=field_input.choices,
                        description=field_input.description
                    )
                # Explicitly finish the builder (no further chaining occurs).
                builder._auto_finish()

        # Add classifications
        if validated.classifications is not None:
            for cls_input in validated.classifications:
                schema.classification(
                    task=cls_input.task,
                    labels=cls_input.labels,
                    multi_label=cls_input.multi_label
                )

        # Add relations
        if validated.relations is not None:
            schema.relations(validated.relations)

        return schema

    @classmethod
    def from_json(cls, json_str: str) -> 'Schema':
        """Create a Schema from a JSON string.

        Args:
            json_str: JSON string with schema definition

        Returns:
            Schema: Constructed schema instance

        Raises:
            ValidationError: If the input data is invalid
            json.JSONDecodeError: If the JSON is malformed

        Example:
            >>> schema_json = '''
            ... {
            ...     "entities": ["company", "person"],
            ...     "classifications": [
            ...         {"task": "sentiment", "labels": ["positive", "negative"]}
            ...     ]
            ... }
            ... '''
            >>> schema = Schema.from_json(schema_json)
        """
        data = json.loads(json_str)
        return cls.from_dict(data)

    def to_dict(self) -> Dict[str, Any]:
        """Convert schema to user-friendly dictionary format.

        Returns:
            Dict: Schema in dictionary format compatible with from_dict()

        Example:
            >>> schema = Schema()
            >>> schema.entities(["company", "person"])
            >>> schema_dict = schema.to_dict()
            >>> # schema_dict can be used with Schema.from_dict()
        """
        result = {}

        # Export entities: prefer the description mapping when present,
        # otherwise emit the plain name list.
        if self.schema["entities"]:
            if self.schema["entity_descriptions"]:
                result["entities"] = dict(self.schema["entity_descriptions"])
            else:
                result["entities"] = list(self.schema["entities"].keys())

        # Export structures, replaying fields in their declared order.
        if self.schema["json_structures"]:
            result["structures"] = {}
            for struct_dict in self.schema["json_structures"]:
                for struct_name, struct_fields in struct_dict.items():
                    fields = []
                    field_order = self._field_orders.get(struct_name, [])

                    for field_name in field_order:
                        if field_name not in struct_fields:
                            continue

                        field_key = f"{struct_name}.{field_name}"
                        metadata = self._field_metadata.get(field_key, {})

                        field_def = {"name": field_name}

                        # Add dtype only when it differs from the default.
                        dtype = metadata.get("dtype", "list")
                        if dtype != "list":
                            field_def["dtype"] = dtype

                        # Add choices if present
                        choices = metadata.get("choices")
                        if choices:
                            field_def["choices"] = choices

                        # Add description if present
                        desc = self.schema.get("json_descriptions", {}).get(struct_name, {}).get(field_name)
                        if desc:
                            field_def["description"] = desc

                        fields.append(field_def)

                    result["structures"][struct_name] = {"fields": fields}

        # Export classifications (task + labels, multi_label only when set).
        if self.schema["classifications"]:
            result["classifications"] = []
            for cls_config in self.schema["classifications"]:
                cls_def = {
                    "task": cls_config["task"],
                    "labels": cls_config["labels"]
                }
                if cls_config.get("multi_label", False):
                    cls_def["multi_label"] = True
                result["classifications"].append(cls_def)

        # Export relations: use declaration order when available, otherwise
        # recover the names from the stored relation dicts.
        if self.schema["relations"]:
            result["relations"] = self._relation_order if self._relation_order else [
list(rel_dict.keys())[0] for rel_dict in self.schema["relations"] + ] + + return result + + +# ============================================================================= +# Main GLiNER2 Class +# ============================================================================= + +class GLiNER2(Extractor): + """ + GLiNER2 Information Extraction Model. + + Provides efficient batch extraction with parallel preprocessing. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._schema_cache = {} + # OPT-11: Cached collator instance for inference + self._inference_collator = None + + @classmethod + def from_api(cls, api_key: str = None, api_base_url: str = None, + timeout: float = 30.0, max_retries: int = 3) -> 'GLiNER2API': + """Load from API instead of local model.""" + from gliner2.api_client import GLiNER2API + return GLiNER2API(api_key=api_key, api_base_url=api_base_url, + timeout=timeout, max_retries=max_retries) + + def create_schema(self) -> Schema: + """Create a new schema builder.""" + return Schema() + + # ========================================================================= + # Main Batch Extraction + # ========================================================================= + + @torch.inference_mode() + def batch_extract( + self, + texts: List[str], + schemas: Union[Schema, List[Schema], Dict, List[Dict]], + batch_size: int = 8, + threshold: float = 0.5, + num_workers: int = 0, + format_results: bool = True, + include_confidence: bool = False, + include_spans: bool = False + ) -> List[Dict[str, Any]]: + """ + Extract from multiple texts with parallel preprocessing. 
+ + Args: + texts: List of input texts + schemas: Single schema or list of schemas + batch_size: Batch size for processing + threshold: Confidence threshold + num_workers: Workers for parallel preprocessing + format_results: Format output nicely + include_confidence: Include confidence scores + include_spans: Include character-level start/end positions + + Returns: + List of extraction results + """ + if not texts: + return [] + + self.eval() + self.processor.change_mode(is_training=False) + + # Normalize schemas + if isinstance(schemas, list): + if len(schemas) != len(texts): + raise ValueError(f"Schema count ({len(schemas)}) != text count ({len(texts)})") + schema_list = schemas + else: + schema_list = [schemas] * len(texts) + + # Build schema dicts and metadata + schema_dicts = [] + metadata_list = [] + + for schema in schema_list: + if hasattr(schema, 'build'): + schema_dict = schema.build() + # Extract classification task names + classification_tasks = [c["task"] for c in schema_dict.get("classifications", [])] + metadata = { + "field_metadata": schema._field_metadata, + "entity_metadata": schema._entity_metadata, + "relation_metadata": getattr(schema, '_relation_metadata', {}), + "field_orders": schema._field_orders, + "entity_order": schema._entity_order, + "relation_order": getattr(schema, '_relation_order', []), + "classification_tasks": classification_tasks + } + else: + schema_dict = schema + # Extract classification task names from dict schema + classification_tasks = [c["task"] for c in schema_dict.get("classifications", [])] + metadata = { + "field_metadata": {}, "entity_metadata": {}, + "relation_metadata": {}, "field_orders": {}, + "entity_order": [], "relation_order": [], + "classification_tasks": classification_tasks + } + + # Ensure classifications have true_label + for cls_config in schema_dict.get("classifications", []): + cls_config.setdefault("true_label", ["N/A"]) + + schema_dicts.append(schema_dict) + metadata_list.append(metadata) + + # 
OPT-9: Skip duplicate normalization β€” _collate_batch handles it + dataset = list(zip(texts, schema_dicts)) + + # OPT-11: Reuse cached collator instance + if self._inference_collator is None: + self._inference_collator = ExtractorCollator(self.processor, is_training=False) + collator = self._inference_collator + + # OPT-12: Skip DataLoader overhead for single-batch inputs + if len(dataset) <= batch_size and num_workers == 0: + batches = [collator(dataset)] + else: + batches = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + collate_fn=collator, + pin_memory=True if torch.cuda.is_available() else False, + ) + + # Process batches + all_results = [] + sample_idx = 0 + device = next(self.parameters()).device + + for batch in batches: + batch = batch.to(device) + batch_results = self._extract_from_batch( + batch, threshold, metadata_list[sample_idx:sample_idx + len(batch)], + include_confidence, include_spans + ) + + if format_results: + for i, result in enumerate(batch_results): + meta = metadata_list[sample_idx + i] + requested_relations = meta.get("relation_order", []) + classification_tasks = meta.get("classification_tasks", []) + batch_results[i] = self.format_results( + result, include_confidence, requested_relations, classification_tasks + ) + + all_results.extend(batch_results) + sample_idx += len(batch) + + return all_results + + def _extract_from_batch( + self, + batch: PreprocessedBatch, + threshold: float, + metadata_list: List[Dict], + include_confidence: bool, + include_spans: bool + ) -> List[Dict[str, Any]]: + """Extract from preprocessed batch.""" + # Encode batch + all_token_embs, all_schema_embs = self.processor.extract_embeddings_from_batch( + self.encoder( + input_ids=batch.input_ids, + attention_mask=batch.attention_mask + ).last_hidden_state, + batch.input_ids, + batch + ) + + results = [] + + for i in range(len(batch)): + try: + sample_result = self._extract_sample( + token_embs=all_token_embs[i], + 
schema_embs=all_schema_embs[i], + schema_tokens_list=batch.schema_tokens_list[i], + task_types=batch.task_types[i], + text_tokens=batch.text_tokens[i], + original_text=batch.original_texts[i], + schema=batch.original_schemas[i], + start_mapping=batch.start_mappings[i], + end_mapping=batch.end_mappings[i], + threshold=threshold, + metadata=metadata_list[i], + include_confidence=include_confidence, + include_spans=include_spans + ) + results.append(sample_result) + except Exception as e: + print(f"Error extracting sample {i}: {e}") + results.append({}) + + return results + + def _extract_sample( + self, + token_embs: torch.Tensor, + schema_embs: List[List[torch.Tensor]], + schema_tokens_list: List[List[str]], + task_types: List[str], + text_tokens: List[str], + original_text: str, + schema: Dict, + start_mapping: List[int], + end_mapping: List[int], + threshold: float, + metadata: Dict, + include_confidence: bool, + include_spans: bool + ) -> Dict[str, Any]: + """Extract from single sample.""" + results = {} + + # Compute span representations if needed + has_span_task = any(t != "classifications" for t in task_types) + span_info = None + if has_span_task and token_embs.numel() > 0: + span_info = self.compute_span_rep(token_embs) + + # Build classification field map + cls_fields = {} + for struct in schema.get("json_structures", []): + for parent, fields in struct.items(): + for fname, fval in fields.items(): + if isinstance(fval, dict) and "choices" in fval: + cls_fields[f"{parent}.{fname}"] = fval["choices"] + + # OPT-3: Use start_mapping length instead of re-tokenizing text + text_len = len(start_mapping) + + for i, (schema_tokens, task_type) in enumerate(zip(schema_tokens_list, task_types)): + if len(schema_tokens) < 4 or not schema_embs[i]: + continue + + schema_name = schema_tokens[2].split(" [DESCRIPTION] ")[0] + embs = torch.stack(schema_embs[i]) + + if task_type == "classifications": + self._extract_classification_result( + results, schema_name, schema, embs, 
schema_tokens + ) + else: + self._extract_span_result( + results, schema_name, task_type, embs, span_info, + schema_tokens, text_tokens, text_len, original_text, + start_mapping, end_mapping, threshold, metadata, + cls_fields, include_confidence, include_spans + ) + + return results + + def _extract_classification_result( + self, + results: Dict, + schema_name: str, + schema: Dict, + embs: torch.Tensor, + schema_tokens: List[str] + ): + """Extract classification result.""" + cls_config = next( + c for c in schema["classifications"] + if schema_tokens[2].startswith(c["task"]) + ) + + cls_embeds = embs[1:] + logits = self.classifier(cls_embeds).squeeze(-1) + + activation = cls_config.get("class_act", "auto") + is_multi = cls_config.get("multi_label", False) + + if activation == "sigmoid": + probs = torch.sigmoid(logits) + elif activation == "softmax": + probs = torch.softmax(logits, dim=-1) + else: + probs = torch.sigmoid(logits) if is_multi else torch.softmax(logits, dim=-1) + + labels = cls_config["labels"] + cls_threshold = cls_config.get("cls_threshold", 0.5) + + if is_multi: + chosen = [(labels[j], probs[j].item()) for j in range(len(labels)) if probs[j].item() >= cls_threshold] + if not chosen: + best = int(torch.argmax(probs).item()) + chosen = [(labels[best], probs[best].item())] + results[schema_name] = chosen + else: + best = int(torch.argmax(probs).item()) + results[schema_name] = (labels[best], probs[best].item()) + + def _extract_span_result( + self, + results: Dict, + schema_name: str, + task_type: str, + embs: torch.Tensor, + span_info: Dict, + schema_tokens: List[str], + text_tokens: List[str], + text_len: int, + original_text: str, + start_mapping: List[int], + end_mapping: List[int], + threshold: float, + metadata: Dict, + cls_fields: Dict, + include_confidence: bool, + include_spans: bool + ): + """Extract span-based results.""" + # Get field names + field_names = [] + for j in range(len(schema_tokens) - 1): + if schema_tokens[j] in ("[E]", "[C]", 
"[R]"): + field_names.append(schema_tokens[j + 1]) + + if not field_names: + results[schema_name] = [] if schema_name == "entities" else {} + return + + # Predict count + count_logits = self.count_pred(embs[0].unsqueeze(0)) + pred_count = int(count_logits.argmax(dim=1).item()) + + if pred_count <= 0 or span_info is None: + if schema_name == "entities": + results[schema_name] = [] + elif task_type == "relations": + results[schema_name] = [] + else: + results[schema_name] = {} + return + + # Get span scores + struct_proj = self.count_embed(embs[1:], pred_count) + span_scores = torch.sigmoid( + torch.einsum("lkd,bpd->bplk", span_info["span_rep"], struct_proj) + ) + + # Extract based on type + if schema_name == "entities": + results[schema_name] = self._extract_entities( + field_names, span_scores, text_len, text_tokens, + original_text, start_mapping, end_mapping, + threshold, metadata, include_confidence, include_spans + ) + elif task_type == "relations": + results[schema_name] = self._extract_relations( + schema_name, field_names, span_scores, pred_count, + text_len, text_tokens, original_text, start_mapping, end_mapping, + threshold, metadata, include_confidence, include_spans + ) + else: + results[schema_name] = self._extract_structures( + schema_name, field_names, span_scores, pred_count, + text_len, text_tokens, original_text, start_mapping, end_mapping, + threshold, metadata, cls_fields, include_confidence, include_spans + ) + + def _extract_entities( + self, + entity_names: List[str], + span_scores: torch.Tensor, + text_len: int, + text_tokens: List[str], + text: str, + start_map: List[int], + end_map: List[int], + threshold: float, + metadata: Dict, + include_confidence: bool, + include_spans: bool + ) -> List[Dict]: + """Extract entity results.""" + scores = span_scores[0, :, -text_len:] + entity_results = OrderedDict() + + for name in metadata.get("entity_order", entity_names): + if name not in entity_names: + continue + + idx = entity_names.index(name) + 
meta = metadata.get("entity_metadata", {}).get(name, {}) + meta_threshold = meta.get("threshold") + ent_threshold = meta_threshold if meta_threshold is not None else threshold + dtype = meta.get("dtype", "list") + + spans = self._find_spans( + scores[idx], ent_threshold, text_len, text, + start_map, end_map + ) + + if dtype == "list": + entity_results[name] = self._format_spans(spans, include_confidence, include_spans) + else: + if spans: + text_val, conf, char_start, char_end = spans[0] + + if include_spans and include_confidence: + entity_results[name] = { + "text": text_val, + "confidence": conf, + "start": char_start, + "end": char_end + } + elif include_spans: + entity_results[name] = { + "text": text_val, + "start": char_start, + "end": char_end + } + elif include_confidence: + entity_results[name] = {"text": text_val, "confidence": conf} + else: + entity_results[name] = text_val + else: + entity_results[name] = "" if not include_spans and not include_confidence else None + + return [entity_results] if entity_results else [] + + def _extract_relations( + self, + rel_name: str, + field_names: List[str], + span_scores: torch.Tensor, + count: int, + text_len: int, + text_tokens: List[str], + text: str, + start_map: List[int], + end_map: List[int], + threshold: float, + metadata: Dict, + include_confidence: bool, + include_spans: bool + ) -> List[Union[Tuple[str, str], Dict]]: + """Extract relation results with optional confidence and position info.""" + instances = [] + + rel_threshold = threshold + if rel_name in metadata.get("relation_metadata", {}): + meta_threshold = metadata["relation_metadata"][rel_name].get("threshold") + rel_threshold = meta_threshold if meta_threshold is not None else threshold + + ordered_fields = metadata.get("field_orders", {}).get(rel_name, field_names) + + for inst in range(count): + scores = span_scores[inst, :, -text_len:] + values = [] + field_data = [] # Store full data for each field + + for fname in ordered_fields: + if fname 
not in field_names: + continue + fidx = field_names.index(fname) + spans = self._find_spans( + scores[fidx], rel_threshold, text_len, text, + start_map, end_map + ) + + if spans: + text_val, conf, char_start, char_end = spans[0] + values.append(text_val) + field_data.append({ + "text": text_val, + "confidence": conf, + "start": char_start, + "end": char_end + }) + else: + values.append(None) + field_data.append(None) + + if len(values) == 2 and values[0] and values[1]: + # Format based on flags + if include_spans and include_confidence: + instances.append({ + "head": field_data[0], + "tail": field_data[1] + }) + elif include_spans: + instances.append({ + "head": {"text": field_data[0]["text"], "start": field_data[0]["start"], "end": field_data[0]["end"]}, + "tail": {"text": field_data[1]["text"], "start": field_data[1]["start"], "end": field_data[1]["end"]} + }) + elif include_confidence: + instances.append({ + "head": {"text": field_data[0]["text"], "confidence": field_data[0]["confidence"]}, + "tail": {"text": field_data[1]["text"], "confidence": field_data[1]["confidence"]} + }) + else: + # Original tuple format for backward compatibility + instances.append((values[0], values[1])) + + return instances + + def _extract_structures( + self, + struct_name: str, + field_names: List[str], + span_scores: torch.Tensor, + count: int, + text_len: int, + text_tokens: List[str], + text: str, + start_map: List[int], + end_map: List[int], + threshold: float, + metadata: Dict, + cls_fields: Dict, + include_confidence: bool, + include_spans: bool + ) -> List[Dict]: + """Extract structure results with optional position tracking.""" + instances = [] + ordered_fields = metadata.get("field_orders", {}).get(struct_name, field_names) + + for inst in range(count): + scores = span_scores[inst, :, -text_len:] + instance = OrderedDict() + + for fname in ordered_fields: + if fname not in field_names: + continue + + fidx = field_names.index(fname) + field_key = f"{struct_name}.{fname}" + 
meta = metadata.get("field_metadata", {}).get(field_key, {}) + meta_threshold = meta.get("threshold") + field_threshold = meta_threshold if meta_threshold is not None else threshold + dtype = meta.get("dtype", "list") + validators = meta.get("validators", []) + + if field_key in cls_fields: + # Classification field - no span positions needed + choices = cls_fields[field_key] + prefix_scores = span_scores[inst, fidx, :-text_len] + + if dtype == "list": + selected = [] + seen = set() + for choice in choices: + if choice in seen: + continue + idx = self._find_choice_idx(choice, text_tokens[:-text_len]) + if idx >= 0 and idx < prefix_scores.shape[0]: + score = prefix_scores[idx, 0].item() + if score >= field_threshold: + if include_confidence: + selected.append({"text": choice, "confidence": score}) + else: + selected.append(choice) + seen.add(choice) + instance[fname] = selected + else: + best = None + best_score = -1.0 + for choice in choices: + idx = self._find_choice_idx(choice, text_tokens[:-text_len]) + if idx >= 0 and idx < prefix_scores.shape[0]: + score = prefix_scores[idx, 0].item() + if score > best_score: + best_score = score + best = choice + if best and best_score >= field_threshold: + if include_confidence: + instance[fname] = {"text": best, "confidence": best_score} + else: + instance[fname] = best + else: + instance[fname] = None + else: + # Regular span field - track positions + spans = self._find_spans( + scores[fidx], field_threshold, text_len, text, + start_map, end_map + ) + + if validators: + spans = [s for s in spans if all(v.validate(s[0]) for v in validators)] + + if dtype == "list": + instance[fname] = self._format_spans(spans, include_confidence, include_spans) + else: + if spans: + text_val, conf, char_start, char_end = spans[0] + + if include_spans and include_confidence: + instance[fname] = { + "text": text_val, + "confidence": conf, + "start": char_start, + "end": char_end + } + elif include_spans: + instance[fname] = { + "text": 
text_val,
                                "start": char_start,
                                "end": char_end
                            }
                        elif include_confidence:
                            instance[fname] = {"text": text_val, "confidence": conf}
                        else:
                            instance[fname] = text_val
                    else:
                        instance[fname] = None

            # Only keep instances that produced at least one non-empty value.
            if any(v is not None and v != [] for v in instance.values()):
                instances.append(instance)

        return instances

    def _find_spans(
        self,
        scores: torch.Tensor,
        threshold: float,
        text_len: int,
        text: str,
        start_map: List[int],
        end_map: List[int]
    ) -> List[Tuple[str, float, int, int]]:
        """Find valid spans above threshold. Returns (text, confidence, char_start, char_end)."""
        # scores is indexed [start_token, span_width]; width 0 means a
        # one-token span, hence end = start + width + 1 below.
        valid = torch.where(scores >= threshold)
        starts, widths = valid

        spans = []
        for start, width in zip(starts.tolist(), widths.tolist()):
            end = start + width + 1
            if 0 <= start < text_len and end <= text_len:
                try:
                    # Map token positions back to character offsets.
                    char_start = start_map[start]
                    char_end = end_map[end - 1]
                    text_span = text[char_start:char_end].strip()
                except (IndexError, KeyError):
                    # Mapping can be shorter than the score grid; skip those.
                    continue

                if text_span:
                    conf = scores[start, width].item()
                    spans.append((text_span, conf, char_start, char_end))

        return spans

    def _format_spans(
        self,
        spans: List[Tuple],
        include_confidence: bool,
        include_spans: bool = False
    ) -> Union[List[str], List[Dict], List[Tuple]]:
        """Format spans with overlap removal and optional position info."""
        if not spans:
            return []

        # Greedy overlap removal: visit spans by descending confidence and
        # keep each one only if it does not overlap an already-kept span.
        sorted_spans = sorted(spans, key=lambda x: x[1], reverse=True)
        selected = []

        for text, conf, start, end in sorted_spans:
            overlap = any(not (end <= s[2] or start >= s[3]) for s in selected)
            if not overlap:
                selected.append((text, conf, start, end))

        # Format based on flags
        if include_spans and include_confidence:
            return [{"text": s[0], "confidence": s[1], "start": s[2], "end": s[3]} for s in selected]
        elif include_spans:
            return [{"text": s[0], "start": s[2], "end": s[3]} for s in selected]
        elif include_confidence:
            return [{"text": s[0], "confidence": s[1]} for
s in selected]
        else:
            return [s[0] for s in selected]

    def _find_choice_idx(self, choice: str, tokens: List[str]) -> int:
        """Find index of choice in tokens (case-insensitive); -1 if absent.

        NOTE(review): the ``in`` branch makes this a substring match, so
        e.g. "red" would also match a token "reddish" — confirm this
        looseness is intentional for choice lookup.
        """
        choice_lower = choice.lower()
        for i, tok in enumerate(tokens):
            if tok.lower() == choice_lower or choice_lower in tok.lower():
                return i
        return -1

    # =========================================================================
    # Result Formatting
    # =========================================================================

    def format_results(
        self,
        results: Dict,
        include_confidence: bool = False,
        requested_relations: Optional[List[str]] = None,
        classification_tasks: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Format extraction results.

        Args:
            results: Raw per-sample extraction output keyed by task name.
            include_confidence: Attach confidence scores to formatted values.
            requested_relations: Relation names declared in the schema; used
                to route values into the "relation_extraction" section even
                when empty.
            classification_tasks: Classification task names declared in the
                schema; classification keys take priority over relation
                detection.
        """
        formatted = {}
        relations = {}
        requested_relations = requested_relations or []
        classification_tasks = classification_tasks or []

        for key, value in results.items():
            # Check if this is a classification task (takes priority)
            is_classification = key in classification_tasks

            # Check if this is a relation
            is_relation = False

            if not is_classification:
                # Check if key is in requested_relations (this takes priority)
                if key in requested_relations:
                    is_relation = True
                # Otherwise, fall back to sniffing the value structure.
                elif isinstance(value, list) and len(value) > 0:
                    # Check for tuple format: [(head, tail), ...]
                    if isinstance(value[0], tuple) and len(value[0]) == 2:
                        is_relation = True
                    # Check for dict format with head/tail keys: [{"head": ..., "tail": ...}, ...]
+ elif isinstance(value[0], dict) and "head" in value[0] and "tail" in value[0]: + is_relation = True + + if is_classification: + # This is a classification task - format and add to formatted dict directly + if isinstance(value, list): + # Multi-label classification + if include_confidence: + formatted[key] = [{"label": l, "confidence": c} for l, c in value] + else: + formatted[key] = [l for l, _ in value] + elif isinstance(value, tuple): + # Single-label classification + label, conf = value + formatted[key] = {"label": label, "confidence": conf} if include_confidence else label + else: + formatted[key] = value + elif is_relation: + # This is a relation - store in relations dict, not formatted + # Relations should always be lists, but handle edge cases defensively + if isinstance(value, list): + relations[key] = value + else: + # Unexpected non-list value for relation - convert to empty list + relations[key] = [] + elif isinstance(value, list): + if len(value) == 0: + if key == "entities": + formatted[key] = {} + else: + formatted[key] = value + elif isinstance(value[0], dict): + if key == "entities": + formatted[key] = self._format_entity_dict(value[0], include_confidence) + else: + formatted[key] = [self._format_struct(v, include_confidence) for v in value] + elif isinstance(value[0], tuple): + if include_confidence: + formatted[key] = [{"label": l, "confidence": c} for l, c in value] + else: + formatted[key] = [l for l, _ in value] + else: + formatted[key] = value + elif isinstance(value, tuple): + label, conf = value + formatted[key] = {"label": label, "confidence": conf} if include_confidence else label + elif isinstance(value, dict): + formatted[key] = self._format_struct(value, include_confidence) + else: + formatted[key] = value + + # Add all requested relations (including empty ones) + for rel in requested_relations: + if rel not in relations: + relations[rel] = [] + + # Only add relation_extraction if we have relations + if relations: + 
formatted["relation_extraction"] = relations + + return formatted + + def _format_entity_dict(self, entities: Dict, include_confidence: bool) -> Dict: + formatted = {} + for name, spans in entities.items(): + if isinstance(spans, list): + unique = [] + seen = set() + for span in spans: + if isinstance(span, tuple): + text, conf, start, end = span + if text and text.lower() not in seen: + seen.add(text.lower()) + unique.append({"text": text, "confidence": conf} if include_confidence else text) + elif isinstance(span, dict): + # Handle dict format (with confidence/spans) + text = span.get("text", "") + if text and text.lower() not in seen: + seen.add(text.lower()) + unique.append(span) + else: + # Handle string format + if span and span.lower() not in seen: + seen.add(span.lower()) + unique.append(span) + formatted[name] = unique + elif isinstance(spans, tuple): + text, conf, _, _ = spans + formatted[name] = {"text": text, "confidence": conf} if include_confidence and text else text + else: + formatted[name] = spans or None + return formatted + + def _format_struct(self, struct: Dict, include_confidence: bool) -> Dict: + formatted = {} + for field, value in struct.items(): + if isinstance(value, list): + unique = [] + seen = set() + for v in value: + if isinstance(v, tuple): + text, conf, _, _ = v + if text and text.lower() not in seen: + seen.add(text.lower()) + unique.append({"text": text, "confidence": conf} if include_confidence else text) + elif isinstance(v, dict): + # Handle dict format (with confidence/spans) + text = v.get("text", "") + if text and text.lower() not in seen: + seen.add(text.lower()) + unique.append(v) + else: + # Handle string format + if v and v.lower() not in seen: + seen.add(v.lower()) + unique.append(v) + formatted[field] = unique + elif isinstance(value, tuple): + text, conf, _, _ = value + formatted[field] = {"text": text, "confidence": conf} if include_confidence and text else text + elif value: + formatted[field] = value + else: + 
formatted[field] = None + return formatted + + # ========================================================================= + # Convenience Methods (route through batch) + # ========================================================================= + + def extract(self, text: str, schema, threshold: float = 0.5, + format_results: bool = True, include_confidence: bool = False, + include_spans: bool = False) -> Dict: + """Extract from single text.""" + return self.batch_extract([text], schema, 1, threshold, 0, format_results, include_confidence, include_spans)[0] + + def extract_entities(self, text: str, entity_types, threshold: float = 0.5, + format_results: bool = True, include_confidence: bool = False, + include_spans: bool = False) -> Dict: + """Extract entities from text.""" + schema = self.create_schema().entities(entity_types) + return self.extract(text, schema, threshold, format_results, include_confidence, include_spans) + + def batch_extract_entities(self, texts: List[str], entity_types, batch_size: int = 8, + threshold: float = 0.5, format_results: bool = True, + include_confidence: bool = False, include_spans: bool = False) -> List[Dict]: + """Batch extract entities.""" + schema = self.create_schema().entities(entity_types) + return self.batch_extract(texts, schema, batch_size, threshold, 0, format_results, include_confidence, include_spans) + + def classify_text(self, text: str, tasks: Dict, threshold: float = 0.5, + format_results: bool = True, include_confidence: bool = False, + include_spans: bool = False) -> Dict: + """Classify text.""" + schema = self.create_schema() + for name, config in tasks.items(): + if isinstance(config, dict) and "labels" in config: + cfg = config.copy() + labels = cfg.pop("labels") + schema.classification(name, labels, **cfg) + else: + schema.classification(name, config) + return self.extract(text, schema, threshold, format_results, include_confidence, include_spans) + + def batch_classify_text(self, texts: List[str], tasks: 
Dict, batch_size: int = 8, + threshold: float = 0.5, format_results: bool = True, + include_confidence: bool = False, include_spans: bool = False) -> List[Dict]: + """Batch classify texts.""" + schema = self.create_schema() + for name, config in tasks.items(): + if isinstance(config, dict) and "labels" in config: + cfg = config.copy() + labels = cfg.pop("labels") + schema.classification(name, labels, **cfg) + else: + schema.classification(name, config) + return self.batch_extract(texts, schema, batch_size, threshold, 0, format_results, include_confidence, include_spans) + + def extract_json(self, text: str, structures: Dict, threshold: float = 0.5, + format_results: bool = True, include_confidence: bool = False, + include_spans: bool = False) -> Dict: + """Extract structured data.""" + schema = self.create_schema() + for parent, fields in structures.items(): + builder = schema.structure(parent) + for spec in fields: + name, dtype, choices, desc = self._parse_field_spec(spec) + builder.field(name, dtype=dtype, choices=choices, description=desc) + return self.extract(text, schema, threshold, format_results, include_confidence, include_spans) + + def batch_extract_json(self, texts: List[str], structures: Dict, batch_size: int = 8, + threshold: float = 0.5, format_results: bool = True, + include_confidence: bool = False, include_spans: bool = False) -> List[Dict]: + """Batch extract structured data.""" + schema = self.create_schema() + for parent, fields in structures.items(): + builder = schema.structure(parent) + for spec in fields: + name, dtype, choices, desc = self._parse_field_spec(spec) + builder.field(name, dtype=dtype, choices=choices, description=desc) + return self.batch_extract(texts, schema, batch_size, threshold, 0, format_results, include_confidence, include_spans) + + def extract_relations(self, text: str, relation_types, threshold: float = 0.5, + format_results: bool = True, include_confidence: bool = False, + include_spans: bool = False) -> Dict: + 
"""Extract relations.""" + schema = self.create_schema().relations(relation_types) + return self.extract(text, schema, threshold, format_results, include_confidence, include_spans) + + def batch_extract_relations(self, texts: List[str], relation_types, batch_size: int = 8, + threshold: float = 0.5, format_results: bool = True, + include_confidence: bool = False, include_spans: bool = False) -> List[Dict]: + """Batch extract relations.""" + schema = self.create_schema().relations(relation_types) + return self.batch_extract(texts, schema, batch_size, threshold, 0, format_results, include_confidence, include_spans) + + def _parse_field_spec(self, spec: Union[str, Dict]) -> Tuple[str, str, Optional[List[str]], Optional[str]]: + """Parse field specification string or dictionary. + + Format: "name::dtype::choices::description" where all parts after name are optional. + - dtype: 'str' for single value, 'list' for multiple values + - choices: [option1|option2|...] for enumerated options + - description: free text description + + Examples: + "restaurant::str::Restaurant name" + "seating::[indoor|outdoor|bar]::Seating preference" + "dietary::[vegetarian|vegan|gluten-free|none]::list::Dietary restrictions" + """ + if isinstance(spec, dict): + return ( + spec.get("name", ""), + spec.get("dtype", "list"), + spec.get("choices"), + spec.get("description") + ) + + parts = spec.split('::') + name = parts[0] + dtype, choices, desc = "list", None, None + dtype_explicitly_set = False + + if len(parts) == 1: + return name, dtype, choices, desc + + for part in parts[1:]: + if part in ['str', 'list']: + dtype = part + dtype_explicitly_set = True + elif part.startswith('[') and part.endswith(']'): + choices = [c.strip() for c in part[1:-1].split('|')] + # Only default to "str" if dtype wasn't explicitly set + if not dtype_explicitly_set: + dtype = "str" + else: + desc = part + + return name, dtype, choices, desc + + +# Aliases +BuilderExtractor = GLiNER2 +SchemaBuilder = Schema 
JsonStructBuilder = StructureBuilder

"""
Pydantic models for validating schema input from JSON/dict.

This module provides validation models for creating GLiNER2 schemas
from JSON or dictionary inputs.
"""

from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator


class FieldInput(BaseModel):
    """Validated definition of a single structure field.

    Attributes:
        name: Non-empty field name.
        dtype: 'str' extracts a single value, 'list' extracts many (default).
        choices: Optional closed set of valid values for the field.
        description: Optional human-readable description of the field.
    """
    name: str = Field(..., min_length=1, description="Field name")
    dtype: Literal["str", "list"] = Field(default="list", description="Data type")
    choices: Optional[List[str]] = Field(default=None, description="Valid choices")
    description: Optional[str] = Field(default=None, description="Field description")

    @field_validator('choices')
    @classmethod
    def validate_choices(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        """Reject an explicitly empty choices list (None is allowed)."""
        if v is not None and not v:
            raise ValueError("choices must contain at least one option")
        return v


class StructureInput(BaseModel):
    """Validated structure block: a non-empty list of field definitions."""
    fields: List[FieldInput] = Field(..., min_length=1, description="List of fields")
class ClassificationInput(BaseModel):
    """Validates a classification task.

    Attributes:
        task: Non-empty task name.
        labels: At least two classification labels, unique and non-blank.
        multi_label: Whether multiple labels can be selected at once.
    """
    task: str = Field(..., min_length=1, description="Task name")
    labels: List[str] = Field(..., min_length=2, description="Classification labels")
    multi_label: bool = Field(default=False, description="Multi-label classification")

    @field_validator('labels')
    @classmethod
    def validate_labels(cls, v: List[str]) -> List[str]:
        """Ensure labels are unique and non-empty."""
        if len(v) != len(set(v)):
            raise ValueError("labels must be unique")
        if any(not label.strip() for label in v):
            raise ValueError("labels cannot be empty strings")
        return v


class SchemaInput(BaseModel):
    """Root schema validation model.

    Attributes:
        entities: List of entity types or dict mapping types to descriptions.
        structures: Dict mapping structure names to structure definitions.
        classifications: List of classification task definitions.
        relations: List of relation types or dict mapping types to config.
    """
    entities: Optional[Union[List[str], Dict[str, str]]] = Field(
        default=None,
        description="Entity types"
    )
    structures: Optional[Dict[str, StructureInput]] = Field(
        default=None,
        description="Structure definitions"
    )
    classifications: Optional[List[ClassificationInput]] = Field(
        default=None,
        description="Classification tasks"
    )
    relations: Optional[Union[List[str], Dict[str, Any]]] = Field(
        default=None,
        description="Relation types"
    ) if False else None  # placeholder never used

    # NOTE: the real field declaration (kept identical to the original) —
    # the line above is replaced by the declaration below.
    relations: Optional[Union[List[str], Dict[str, Dict[str, Any]]]] = Field(
        default=None,
        description="Relation types"
    )

    @staticmethod
    def _validate_named_collection(v, plural: str, singular: str):
        """Shared guard for list/dict sections (was duplicated per validator).

        Checks, in the original order: collection non-empty, names non-blank,
        and (lists only) names unique. Error messages are kept byte-identical
        to the previous per-field validators.
        """
        if isinstance(v, list):
            if len(v) == 0:
                raise ValueError(f"{plural} list cannot be empty")
            if any(not name.strip() for name in v):
                raise ValueError(f"{singular} names cannot be empty strings")
            if len(v) != len(set(v)):
                raise ValueError(f"{singular} names must be unique")
        elif isinstance(v, dict):
            if len(v) == 0:
                raise ValueError(f"{plural} dict cannot be empty")
            if any(not key.strip() for key in v.keys()):
                raise ValueError(f"{singular} names cannot be empty strings")
        return v

    @field_validator('entities')
    @classmethod
    def validate_entities(
        cls,
        v: Optional[Union[List[str], Dict[str, str]]]
    ) -> Optional[Union[List[str], Dict[str, str]]]:
        """Validate entities format."""
        if v is None:
            return v
        return cls._validate_named_collection(v, "entities", "entity")

    @field_validator('structures')
    @classmethod
    def validate_structures(
        cls,
        v: Optional[Dict[str, StructureInput]]
    ) -> Optional[Dict[str, StructureInput]]:
        """Validate structures format."""
        if v is None:
            return v
        return cls._validate_named_collection(v, "structures", "structure")

    @field_validator('classifications')
    @classmethod
    def validate_classifications(
        cls,
        v: Optional[List[ClassificationInput]]
    ) -> Optional[List[ClassificationInput]]:
        """Validate classifications format: non-empty, unique task names."""
        if v is None:
            return v
        if len(v) == 0:
            raise ValueError("classifications list cannot be empty")
        task_names = [cls_task.task for cls_task in v]
        if len(task_names) != len(set(task_names)):
            raise ValueError("classification task names must be unique")
        return v

    @field_validator('relations')
    @classmethod
    def validate_relations(
        cls,
        v: Optional[Union[List[str], Dict[str, Dict[str, Any]]]]
    ) -> Optional[Union[List[str], Dict[str, Dict[str, Any]]]]:
        """Validate relations format."""
        if v is None:
            return v
        return cls._validate_named_collection(v, "relations", "relation")

    @model_validator(mode='after')
    def validate_at_least_one_section(self) -> 'SchemaInput':
        """Ensure at least one section is provided."""
        if all(
            getattr(self, field) is None
            for field in ['entities', 'structures', 'classifications', 'relations']
        ):
            raise ValueError(
                "At least one of entities, structures, classifications, "
                "or relations must be provided"
            )
        return self
import torch
import torch.nn as nn
import torch.nn.functional as F


def create_mlp(input_dim, intermediate_dims, output_dim, dropout=0.1, activation="gelu", add_layer_norm=False):
    """Build a feed-forward MLP as an ``nn.Sequential``.

    Args:
        input_dim: Size of the input features.
        intermediate_dims: Hidden layer sizes, in order; may be empty, in
            which case the result is a single Linear projection.
        output_dim: Size of the final projection.
        dropout: Dropout probability applied after each hidden activation
            (skipped when <= 0).
        activation: One of 'relu', 'tanh', 'sigmoid', 'leaky_relu', 'gelu'.
        add_layer_norm: Insert a LayerNorm after each hidden linear layer.

    Returns:
        ``nn.Sequential`` ending in a Linear projection to ``output_dim``.
    """
    activations = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "leaky_relu": nn.LeakyReLU,
        "gelu": nn.GELU,
    }
    layers = []
    prev_dim = input_dim
    for hidden_dim in intermediate_dims:
        layers.append(nn.Linear(prev_dim, hidden_dim))
        if add_layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(activations[activation]())
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        prev_dim = hidden_dim
    layers.append(nn.Linear(prev_dim, output_dim))
    return nn.Sequential(*layers)
+ """ + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_layers = num_layers + + self.in_projector = nn.Linear(input_size, hidden_size) + + encoder = nn.TransformerEncoderLayer( + d_model=hidden_size, + nhead=num_heads, + dim_feedforward=hidden_size * 2, + dropout=dropout, + batch_first=True + ) + + self.transformer = nn.TransformerEncoder(encoder, num_layers=num_layers) + + self.out_projector = create_mlp( + input_dim=hidden_size + input_size, + intermediate_dims=[input_size, input_size], + output_dim=input_size, + dropout=0., + activation="relu", + add_layer_norm=False + ) + + def forward(self, x): + """ + Args: + x (Tensor): Input tensor of shape (L, M, input_size). + Returns: + Tensor: Output tensor of shape (L, M, input_size). + """ + original_x = x + # Project input to hidden size. + x = self.in_projector(x) + # Apply transformer encoder.xx + x = self.transformer(x) + # Concatenate original input with transformer output. + x = torch.cat([x, original_x], dim=-1) + # Project back to input size. + x = self.out_projector(x) + return x + + +class CountLSTM(nn.Module): + def __init__(self, hidden_size, max_count=20): + """ + Initializes the module with a learned positional embedding for count steps and a GRU. + """ + super().__init__() + self.hidden_size = hidden_size + self.max_count = max_count + # Learned positional embeddings for count steps: shape (max_count, hidden_size) + self.pos_embedding = nn.Embedding(max_count, hidden_size) + # Use a GRU layer; input shape is (seq_len, batch, input_size) + self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size) + # Projector layer: combines GRU output with original embeddings. 
class CountLSTM(nn.Module):
    """GRU over learned count-step embeddings, seeded by field embeddings."""

    def __init__(self, hidden_size, max_count=20):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_count = max_count
        # One learned position vector per count step: (max_count, hidden_size).
        self.pos_embedding = nn.Embedding(max_count, hidden_size)
        # GRU consumes sequences shaped (seq_len, batch, input_size).
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)
        # Fuses GRU output with the original field embeddings.
        self.projector = create_mlp(
            input_dim=hidden_size * 2,
            intermediate_dims=[hidden_size * 4],
            output_dim=hidden_size,
            dropout=0.,
            activation="relu",
            add_layer_norm=False,
        )

    def forward(self, pc_emb: torch.Tensor, gold_count_val: int) -> torch.Tensor:
        """Return count-aware embeddings of shape (gold_count_val, M, hidden).

        Args:
            pc_emb: Field embeddings of shape (M, hidden_size).
            gold_count_val: Number of count steps to unroll (capped at
                ``max_count``).
        """
        num_fields, dim = pc_emb.shape
        steps = min(gold_count_val, self.max_count)
        step_ids = torch.arange(steps, device=pc_emb.device)
        # Broadcast each positional vector across fields: (steps, M, hidden).
        pos_seq = self.pos_embedding(step_ids).unsqueeze(1).expand(steps, num_fields, dim)
        # Seed the GRU state with the field embeddings: (1, M, hidden).
        output, _ = self.gru(pos_seq, pc_emb.unsqueeze(0))
        fused = torch.cat([output, pc_emb.unsqueeze(0).expand_as(output)], dim=-1)
        return self.projector(fused)
class CountLSTMv2(nn.Module):
    """Count-aware module: GRU over count positions plus a downscaled transformer.

    Written to be ONNX-export friendly: the count index vector is built once
    at ``max_count`` length and then sliced, and ``gold_count_val`` may be a
    0-D tensor rather than a Python int.
    """

    def __init__(self, hidden_size, max_count=20):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_count = max_count
        self.pos_embedding = nn.Embedding(max_count, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.transformer = DownscaledTransformer(
            hidden_size,
            hidden_size=128,
            num_heads=4,
            num_layers=2,
            dropout=0.1,
        )

    def forward(self, pc_emb: torch.Tensor, gold_count_val: int) -> torch.Tensor:
        """[M, D] field embeddings -> [steps, M, D] count-aware embeddings."""
        num_fields, _ = pc_emb.size()  # symbolic sizes

        # Clamp without dropping to a Python int.
        steps = min(gold_count_val, self.max_count)

        # Build the full index vector once, then slice — both ONNX-supported.
        count_idx = torch.arange(self.max_count, device=pc_emb.device)[:steps]

        pos_seq = self.pos_embedding(count_idx)                 # (steps, D)
        pos_seq = pos_seq.unsqueeze(1).expand(-1, num_fields, -1)  # (steps, M, D)

        output, _ = self.gru(pos_seq, pc_emb.unsqueeze(0))      # (steps, M, D)
        return self.transformer(output + pc_emb.unsqueeze(0).expand_as(output))
+ """ + + def __init__(self, + hidden_size: int, + max_count: int = 20, + n_experts: int = 4, + ffn_mult: int = 2, + dropout: float = 0.1): + super().__init__() + self.hidden_size, self.max_count, self.n_experts = ( + hidden_size, max_count, n_experts) + + # ───── positional encoding + recurrent core ───── + self.pos_embedding = nn.Embedding(max_count, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size) + + # ───── expert parameters (all packed in two tensors) ───── + inner = hidden_size * ffn_mult + # W1 : [E, D, inner]   b1 : [E, inner] + self.w1 = nn.Parameter(torch.empty(n_experts, hidden_size, inner)) + self.b1 = nn.Parameter(torch.zeros(n_experts, inner)) + # W2 : [E, inner, D]  b2 : [E, D] + self.w2 = nn.Parameter(torch.empty(n_experts, inner, hidden_size)) + self.b2 = nn.Parameter(torch.zeros(n_experts, hidden_size)) + + # better than default init for large mats + nn.init.xavier_uniform_(self.w1) + nn.init.xavier_uniform_(self.w2) + + self.dropout = nn.Dropout(dropout) + + # ───── router / gating network ───── + self.router = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.GELU(), + nn.Linear(hidden_size, n_experts), # logits + nn.Softmax(dim=-1), # gates sum-to-1 + ) + + # --------------------------------------------------- + def forward(self, pc_emb: torch.Tensor, gold_count_val: int) -> torch.Tensor: + """ + pc_emb : [M, D] field embeddings + gold_count_val : int (# count steps to unroll) + returns : [L, M, D] count-aware embeddings + """ + M, D = pc_emb.shape + L = min(gold_count_val, self.max_count) + + idx = torch.arange(L, device=pc_emb.device) + pos_seq = self.pos_embedding(idx).unsqueeze(1).expand(L, M, D) + + h0 = pc_emb.unsqueeze(0) # [1, M, D] + h, _ = self.gru(pos_seq, h0) # [L, M, D] + + # ───── routing / gating ───── + gates = self.router(h) # [L, M, E] + + # ───── expert FFN: run *all* experts in parallel ───── + # 1st linear + x = torch.einsum('lmd,edh->lmeh', h, self.w1) + self.b1 # [L, M, E, inner] + x = F.gelu(x) + x = 
"""
GLiNER2 Extractor Model with Optimized Batch Processing

This module contains the core Extractor model that accepts PreprocessedBatch
directly for efficient GPU-only forward passes.
"""

import os
import tempfile
from typing import Dict, List, Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from gliner.modeling.span_rep import SpanRepLayer
from gliner2.layers import CountLSTMoE, CountLSTM, create_mlp, CountLSTMv2
from gliner2.processor import SchemaTransformer, PreprocessedBatch, SamplingConfig
from safetensors.torch import save_file, load_file
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoModel,
    AutoConfig,
    AutoTokenizer,
)


class ExtractorConfig(PretrainedConfig):
    """HF-style configuration for the Extractor model.

    Args:
        model_name: Backbone encoder checkpoint name.
        max_width: Maximum span width considered by the span layer.
        counting_layer: Count-embedding variant; one of 'count_lstm',
            'count_lstm_moe' or 'count_lstm_v2'.
        token_pooling: Token pooling strategy passed to the processor.
    """
    model_type = "extractor"

    def __init__(self,
                 model_name: str = "bert-base-uncased",
                 max_width: int = 8,
                 counting_layer: str = "count_lstm",
                 token_pooling: str = "first",
                 **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.max_width = max_width
        self.counting_layer = counting_layer
        self.token_pooling = token_pooling
+ + Example: + >>> processor = SchemaTransformer(model_name) + >>> model = Extractor.from_pretrained(repo_id) + >>> + >>> # Training + >>> loader = DataLoader(dataset, collate_fn=processor.collate_fn_train) + >>> for batch in loader: + ... batch = batch.to(device) + ... loss = model(batch)["total_loss"] + """ + config_class = ExtractorConfig + + def __init__(self, config: ExtractorConfig, encoder_config=None, tokenizer=None): + super().__init__(config) + self.config = config + self.max_width = config.max_width + + # Initialize processor + if tokenizer is not None: + self.processor = SchemaTransformer( + tokenizer=tokenizer, + token_pooling=config.token_pooling + ) + else: + self.processor = SchemaTransformer( + config.model_name, + token_pooling=config.token_pooling + ) + + # Load encoder + if encoder_config is not None: + self.encoder = AutoModel.from_config(encoder_config, trust_remote_code=True) + else: + self.encoder = AutoModel.from_pretrained(config.model_name, trust_remote_code=True) + + self.encoder.resize_token_embeddings(len(self.processor.tokenizer)) + self.hidden_size = self.encoder.config.hidden_size + + # Span representation layer + self.span_rep = SpanRepLayer( + span_mode="markerV0", + hidden_size=self.hidden_size, + max_width=self.max_width, + dropout=0.1, + ) + + # Classifier for classification tasks + self.classifier = create_mlp( + input_dim=self.hidden_size, + intermediate_dims=[self.hidden_size * 2], + output_dim=1, + dropout=0., + activation="relu", + add_layer_norm=False + ) + + # Count prediction layer + self.count_pred = create_mlp( + input_dim=self.hidden_size, + intermediate_dims=[self.hidden_size * 2], + output_dim=20, + dropout=0., + activation="relu", + add_layer_norm=False + ) + + # Count embedding module + if config.counting_layer == "count_lstm": + self.count_embed = CountLSTM(self.hidden_size) + elif config.counting_layer == "count_lstm_moe": + self.count_embed = CountLSTMoE( + hidden_size=self.hidden_size, + n_experts=4, + 
                ffn_mult=2,
                dropout=0.1
            )
        elif config.counting_layer == "count_lstm_v2":
            self.count_embed = CountLSTMv2(hidden_size=self.hidden_size)

        # LoRA adapter state (populated later by the LoRA utilities).
        self._lora_layers = {}
        self._adapter_config = None

        self._print_config(config)

    def _print_config(self, config):
        """Print a human-readable summary of the model configuration."""
        # NOTE(review): uses print() rather than the logging module.
        print("=" * 60)
        print("🧠 Model Configuration")
        print("=" * 60)
        print(f"Encoder model : {config.model_name}")
        print(f"Counting layer : {config.counting_layer}")
        print(f"Token pooling : {config.token_pooling}")
        print("=" * 60)

    # =========================================================================
    # Main Forward Pass
    # =========================================================================

    def forward(
        self,
        batch: PreprocessedBatch,
        return_individual_losses: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass on preprocessed batch.

        Args:
            batch: PreprocessedBatch from processor.collate_fn_train()
            return_individual_losses: If True, return per-sample losses

        Returns:
            Dict with:
                - total_loss: Sum of all losses
                - classification_loss: Classification task loss
                - structure_loss: Span extraction loss
                - count_loss: Count prediction loss
                - batch_size: Number of valid samples
        """
        if len(batch) == 0:
            return self._empty_loss_dict()

        # Move the whole batch to the model's device before encoding.
        device = next(self.parameters()).device
        batch = batch.to(device)

        # Encode batch through transformer
        all_token_embs, all_schema_embs = self._encode_batch(batch)

        # Compute losses for each sample
        cls_losses = []
        struct_losses = []
        count_losses = []
        individual = []
        valid_samples = 0

        for i in range(len(batch)):
            try:
                sample_losses = self._compute_sample_loss(
                    token_embeddings=all_token_embs[i],
                    embs_per_schema=all_schema_embs[i],
                    task_types=batch.task_types[i],
                    structure_labels=batch.structure_labels[i],
                    device=device
                )

                cls_losses.append(sample_losses["classification"])
                struct_losses.append(sample_losses["structure"])
count_losses.append(sample_losses["count"]) + + if return_individual_losses: + individual.append({ + "total_loss": ( + sample_losses["classification"] + + sample_losses["structure"] + + sample_losses["count"] + ).item(), + "classification_loss": sample_losses["classification"].item(), + "structure_loss": sample_losses["structure"].item(), + "count_loss": sample_losses["count"].item(), + }) + + valid_samples += 1 + + except Exception as e: + print(f"Error processing sample {i}: {e}") + zero = torch.tensor(0.0, device=device) + cls_losses.append(zero) + struct_losses.append(zero) + count_losses.append(zero) + + if return_individual_losses: + individual.append({ + "total_loss": 0.0, + "classification_loss": 0.0, + "structure_loss": 0.0, + "count_loss": 0.0, + "error": str(e) + }) + + if valid_samples == 0: + result = self._empty_loss_dict() + if return_individual_losses: + result["individual_losses"] = individual + return result + + # Aggregate losses + total_cls = torch.stack(cls_losses).sum() + total_struct = torch.stack(struct_losses).sum() + total_count = torch.stack(count_losses).sum() + total_loss = total_cls + total_struct + total_count + + result = { + "total_loss": total_loss, + "classification_loss": total_cls, + "structure_loss": total_struct, + "count_loss": total_count, + "batch_size": valid_samples + } + + if return_individual_losses: + result["individual_losses"] = individual + + return result + + def _empty_loss_dict(self) -> Dict[str, torch.Tensor]: + """Return empty loss dictionary.""" + device = next(self.parameters()).device + return { + "total_loss": torch.tensor(0.0, device=device, requires_grad=True), + "classification_loss": torch.tensor(0.0, device=device), + "structure_loss": torch.tensor(0.0, device=device), + "count_loss": torch.tensor(0.0, device=device), + "batch_size": 0 + } + + # ========================================================================= + # Encoding + # 
    # =========================================================================

    def _encode_batch(
        self,
        batch: PreprocessedBatch
    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        """
        Encode batch through transformer and extract embeddings.

        Args:
            batch: PreprocessedBatch with input_ids and attention_mask

        Returns:
            - all_token_embs: List of (text_len, hidden) per sample
            - all_schema_embs: List of schema embeddings per sample
        """
        # Forward through encoder
        outputs = self.encoder(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask
        )
        token_embeddings = outputs.last_hidden_state

        # Extract embeddings using processor (splits text vs schema tokens).
        return self.processor.extract_embeddings_from_batch(
            token_embeddings,
            batch.input_ids,
            batch
        )

    # =========================================================================
    # Loss Computation
    # =========================================================================

    def _compute_sample_loss(
        self,
        token_embeddings: torch.Tensor,
        embs_per_schema: List[List[torch.Tensor]],
        task_types: List[str],
        structure_labels: List[Any],
        device: torch.device
    ) -> Dict[str, torch.Tensor]:
        """
        Compute all losses for a single sample.

        Args:
            token_embeddings: (text_len, hidden) text token embeddings
            embs_per_schema: List of schema embeddings
            task_types: Task type for each schema
            structure_labels: Labels for each schema
            device: Computation device

        Returns:
            Dict with classification, structure, and count losses
        """
        cls_loss = torch.tensor(0.0, device=device)
        struct_loss = torch.tensor(0.0, device=device)
        count_loss = torch.tensor(0.0, device=device)

        # Compute span representations once, only if some schema needs spans.
        has_span_task = any(t != "classifications" for t in task_types)
        span_info = None
        if has_span_task and token_embeddings.numel() > 0:
            span_info = self.compute_span_rep(token_embeddings)

        all_counts = []
        all_p_embs = []

        for i, task_type in enumerate(task_types):
            if not embs_per_schema[i]:
                continue

            schema_emb = torch.stack(embs_per_schema[i])

            if task_type == "classifications":
                # Classification loss: one binary logit per label.
                cls_embeds = schema_emb[1:]  # Skip [P] token
                logits = self.classifier(cls_embeds).squeeze(-1)
                labels = torch.tensor(structure_labels[i], dtype=torch.float, device=device)
                cls_loss = cls_loss + F.binary_cross_entropy_with_logits(
                    logits, labels, reduction="sum"
                )
            else:
                # Structure loss
                structure = structure_labels[i]

                if structure[0] == 0:
                    # No instances to extract
                    continue

                if span_info is not None:
                    struct_loss = struct_loss + self.compute_struct_loss(
                        span_info["span_rep"],
                        schema_emb,
                        structure,
                        span_info["span_mask"]
                    )

                # Collect for count loss (skip entities).
                if task_type != "entities":
                    # Cap at 19: count_pred has 20 classes (0..19).
                    all_counts.append(min(structure[0], 19))
                    all_p_embs.append(schema_emb[0])

        # Count loss over the collected [P] embeddings.
        if all_counts and all_p_embs:
            counts = torch.tensor(all_counts, dtype=torch.long, device=device)
            p_embs = torch.stack(all_p_embs)
            count_loss = F.cross_entropy(self.count_pred(p_embs), counts, reduction="sum")

        return {
            "classification": cls_loss,
            "structure": struct_loss,
            "count": count_loss
        }

    #
=========================================================================
    # Span Representation
    # =========================================================================

    def compute_span_rep(self, token_embeddings: torch.Tensor) -> Dict[str, Any]:
        """
        Compute span representations for token embeddings.

        Enumerates all candidate spans of width 1..max_width starting at each
        token position; spans running past the end of the text are marked
        invalid with (-1, -1).

        Args:
            token_embeddings: (text_len, hidden) token embeddings

        Returns:
            Dict with span_rep, spans_idx, and span_mask
        """
        text_length = len(token_embeddings)
        device = token_embeddings.device

        spans_idx = []
        for i in range(text_length):
            for j in range(self.max_width):
                if i + j < text_length:
                    spans_idx.append((i, i + j))
                else:
                    spans_idx.append((-1, -1))

        spans_idx = torch.tensor([spans_idx], dtype=torch.long, device=device)

        # Mask invalid spans
        span_mask = (spans_idx[:, :, 0] == -1) | (spans_idx[:, :, 1] == -1)

        # Replace invalid with (0, 0) for safe indexing
        safe_spans = torch.where(
            span_mask.unsqueeze(-1),
            torch.zeros_like(spans_idx),
            spans_idx
        )

        # Compute span representations
        span_rep = self.span_rep(
            token_embeddings.unsqueeze(0),
            safe_spans
        ).squeeze(0)

        return {
            "span_rep": span_rep,
            "spans_idx": spans_idx,
            "span_mask": span_mask
        }

    def compute_struct_loss(
        self,
        span_rep: torch.Tensor,
        schema_emb: torch.Tensor,
        structure: List[Any],
        span_mask: torch.Tensor,
        masking_rate: float = 0.5
    ) -> torch.Tensor:
        """
        Compute structure extraction loss with negative span masking.

        Args:
            span_rep: span representations — given the 'lkd' einsum below this
                appears to be (text_len, max_width, hidden), not flat
                (num_spans, hidden); TODO confirm against SpanRep's output.
            schema_emb: (num_fields + 1, hidden) schema embeddings
            structure: [count, spans] structure labels
            span_mask: (1, num_spans) mask for invalid spans
            masking_rate: Probability of masking negative spans

        Returns:
            Structure loss tensor
        """
        # NOTE(review): same magic cap of 19 as in _compute_sample_loss —
        # keep these in sync (presumably count_embed supports 20 instances).
        gold_count = min(structure[0], 19)
        struct_proj = self.count_embed(schema_emb[1:], gold_count)
        # scores: (count, fields, text_len, width)
        scores = torch.einsum('lkd,bpd->bplk', span_rep, struct_proj)

        # Create label tensor
        labs = torch.zeros_like(scores)

        for i in range(gold_count):
            gold_spans = structure[1][i]
            for k, span in enumerate(gold_spans):
                if span is None or span == (-1, -1):
                    continue
                if isinstance(span, tuple):
                    # Single gold span for this field.
                    start, end = span
                    width = end - start
                    if 0 <= start < scores.shape[2] and 0 <= width < scores.shape[3]:
                        labs[i, k, start, width] = 1
                elif isinstance(span, list):
                    # Multi-valued field: several gold spans.
                    for sub in span:
                        if sub is None or sub == (-1, -1):
                            continue
                        start, end = sub
                        width = end - start
                        if 0 <= start < scores.shape[2] and 0 <= width < scores.shape[3]:
                            labs[i, k, start, width] = 1

        # Apply negative masking: randomly drop a fraction of negative spans
        # from the loss to balance the overwhelming negative class (training
        # mode only).
        if masking_rate > 0.0 and self.training:
            negative = (labs == 0)
            random_mask = torch.rand_like(scores) < masking_rate
            to_mask = negative & random_mask
            loss_mask = (~to_mask).float()
        else:
            loss_mask = torch.ones_like(scores)

        # Compute masked loss; invalid (padding) spans are zeroed out last.
        loss = F.binary_cross_entropy_with_logits(scores, labs, reduction="none")
        loss = loss * loss_mask
        loss = loss.view(loss.shape[0], loss.shape[1], -1) * (~span_mask[0]).float()

        return loss.sum()

    # =========================================================================
    # Hugging Face Methods
    # =========================================================================

    def push_to_hub(self, repo_id: str, private: bool = True):
        """Push model to Hugging Face Hub (model weights, configs, tokenizer)."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.save_pretrained(tmp_dir)
            super().push_to_hub(repo_id=repo_id,
                                save_dir=tmp_dir, private=private)
            self.processor.tokenizer.push_to_hub(repo_id)

    @classmethod
    def from_pretrained(cls, repo_or_dir: str, **kwargs):
        """
        Load model from Hugging Face Hub or local directory.

        To use a LoRA adapter:
        1. Load the base model first
        2. Then load the adapter using model.load_adapter()

        Example:
            model = Extractor.from_pretrained("base-model-name")
            model.load_adapter("path/to/adapter")
        """
        from huggingface_hub import hf_hub_download

        def download_or_local(repo, filename):
            # Local directory takes precedence over Hub download.
            if os.path.isdir(repo):
                return os.path.join(repo, filename)
            return hf_hub_download(repo, filename)

        config_path = download_or_local(repo_or_dir, "config.json")
        config = cls.config_class.from_pretrained(config_path)

        encoder_config_path = download_or_local(repo_or_dir, "encoder_config/config.json")
        encoder_config = AutoConfig.from_pretrained(encoder_config_path)

        tokenizer = AutoTokenizer.from_pretrained(repo_or_dir)
        model = cls(config, encoder_config=encoder_config, tokenizer=tokenizer)

        # Load weights: prefer safetensors, fall back to pytorch_model.bin.
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load checkpoints from trusted sources.
        try:
            model_path = download_or_local(repo_or_dir, "model.safetensors")
            state_dict = load_file(model_path)
        except Exception:
            model_path = download_or_local(repo_or_dir, "pytorch_model.bin")
            state_dict = torch.load(model_path, map_location="cpu")

        # Handle embedding size mismatch (e.g. special tokens added after the
        # checkpoint was saved): pad with randomly initialized rows.
        # NOTE(review): torch.randn uses default dtype/CPU — may mismatch an
        # fp16 checkpoint; confirm dtype handling if loading half-precision.
        try:
            saved_emb = state_dict["encoder.embeddings.word_embeddings.weight"]
            model_emb = model.encoder.embeddings.word_embeddings.weight
            if saved_emb.shape[0] != model_emb.shape[0]:
                extra = model_emb.shape[0] - saved_emb.shape[0]
                state_dict["encoder.embeddings.word_embeddings.weight"] = torch.cat([
                    saved_emb,
                    torch.randn(extra, saved_emb.shape[1]) * 0.02
                ], dim=0)
        except KeyError:
            pass

        model.load_state_dict(state_dict)
        return model

    def load_adapter(self, adapter_path: str) -> 'Extractor':
        """
        Load a LoRA adapter onto this model.

        If an adapter is already loaded, it will be unloaded first.

        Args:
            adapter_path: Path to adapter directory

        Returns:
            self for method chaining

        Example:
            model.load_adapter("./legal_adapter")
            results = model.extract_entities(text, entities)
        """
        from gliner2.training.lora import load_lora_adapter, LoRAAdapterConfig

        # Load adapter config
        config = LoRAAdapterConfig.load(adapter_path)

        self._lora_layers = load_lora_adapter(self, adapter_path, auto_unload=True)
        self._adapter_config = config
        return self

    def unload_adapter(self) -> 'Extractor':
        """
        Unload current LoRA adapter, restoring base model.

        No-op if no adapter is loaded.

        Returns:
            self for method chaining
        """
        from gliner2.training.lora import unload_lora_adapter

        if self._lora_layers:
            unload_lora_adapter(self)
            self._lora_layers = {}
            self._adapter_config = None
        return self

    def merge_lora(self) -> 'Extractor':
        """
        Merge LoRA weights into base model and remove adapter structure.

        After calling this, the model will have standard Linear layers with
        merged weights. LoRA adapters are permanently removed.

        Returns:
            self for method chaining

        Raises:
            ValueError: If no adapter is loaded

        Example:
            model.load_adapter("./my_adapter")
            model.merge_lora()  # Now model has merged weights, no LoRA
            model.save_pretrained("./merged_model")
        """
        if not self._lora_layers:
            raise ValueError("No adapter loaded. Nothing to merge.")

        from gliner2.training.lora import merge_lora_weights
        merge_lora_weights(self)
        self._lora_layers = {}
        self._adapter_config = None
        return self

    def save_adapter(self, save_path: str) -> None:
        """
        Save only the LoRA adapter (not full model).

        Args:
            save_path: Directory to save adapter

        Raises:
            ValueError: If no adapter is loaded
        """
        if not self._lora_layers:
            raise ValueError("No adapter loaded. Use save_pretrained for full model.")

        from gliner2.training.lora import save_lora_adapter
        save_lora_adapter(self, save_path)

    @property
    def has_adapter(self) -> bool:
        """Check if an adapter is currently loaded."""
        return bool(self._lora_layers)

    @property
    def adapter_config(self):
        """Get config of loaded adapter, or None."""
        return self._adapter_config

    def save_pretrained(
        self,
        save_directory: str,
        save_adapter_only: bool = False,
        merge_lora: bool = True,
        **kwargs
    ):
        """
        Save model to directory.

        Args:
            save_directory: Where to save
            save_adapter_only: If True and adapter loaded, save only adapter
            merge_lora: If True and LoRA active, merge LoRA weights into base
                model and remove adapter structure before saving.
                WARNING: This permanently removes LoRA from the model instance.
        """
        if save_adapter_only:
            if not self._lora_layers:
                raise ValueError("save_adapter_only=True but no adapter loaded")
            self.save_adapter(save_directory)
            return

        # Handle LoRA merging if requested (mutates this model instance).
        if merge_lora and self._lora_layers:
            self.merge_lora()

        # Original save logic
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        encoder_config_path = os.path.join(save_directory, "encoder_config")
        os.makedirs(encoder_config_path, exist_ok=True)
        self.encoder.config.save_pretrained(encoder_config_path)

        model_path = os.path.join(save_directory, "model.safetensors")
        save_file(self.state_dict(), model_path)

        self.processor.tokenizer.save_pretrained(save_directory)
\ No newline at end of file
diff --git a/packages/GLiNER2/gliner2/old_trainer.py b/packages/GLiNER2/gliner2/old_trainer.py
new file mode 100644
index 0000000..d50bd30
--- /dev/null
+++ b/packages/GLiNER2/gliner2/old_trainer.py
@@ -0,0 +1,322 @@
"""
GLiNER2 Trainer with Optimized DataLoader-based Preprocessing

This module provides training utilities that leverage parallel preprocessing
via DataLoader workers for
maximum GPU utilization.
"""

import json
import random
from typing import Union, List

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments

from gliner2.processor import SchemaTransformer, PreprocessedBatch, SamplingConfig


# =============================================================================
# Dataset
# =============================================================================

class ExtractorDataset(Dataset):
    """
    Dataset for GLiNER2 training.

    Returns (text, schema) tuples that are processed by the collate function.

    Args:
        data_paths: Path or list of paths to JSONL training files
        shuffle: Whether to shuffle data on load (default: True)

    JSONL Format:
        {"input": "text here", "output": {"entities": {...}, ...}}
    """

    def __init__(self, data_paths: Union[str, List[str]], shuffle: bool = True):
        if isinstance(data_paths, str):
            data_paths = [data_paths]

        print(f"Loading {len(data_paths)} file(s) for training...")

        self.data = []
        for path in data_paths:
            # NOTE(review): json.loads on every line — a blank trailing line
            # in a JSONL file will raise here; consider skipping empty lines.
            with open(path, "r", encoding="utf-8") as f:
                self.data.extend([json.loads(line) for line in f])

        if shuffle:
            random.shuffle(self.data)

        print(f"Loaded {len(self.data)} records from {len(data_paths)} file(s).")

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple:
        """Return (text, schema) tuple."""
        record = self.data[idx]
        return record["input"], record["output"]


# =============================================================================
# Data Collator
# =============================================================================

class ExtractorDataCollator:
    """
    Data collator that uses processor's collate function.

    This enables parallel preprocessing via DataLoader workers.

    Args:
        processor: SchemaTransformer instance
        is_training: Whether in training mode (enables augmentation)
    """

    def __init__(self, processor: SchemaTransformer, is_training: bool = True):
        self.processor = processor
        self.is_training = is_training

    def __call__(self, batch: List[tuple]) -> PreprocessedBatch:
        """
        Collate batch of (text, schema) tuples into PreprocessedBatch.

        Args:
            batch: List of (text, schema) tuples from dataset

        Returns:
            PreprocessedBatch ready for model.forward()
        """
        if self.is_training:
            return self.processor.collate_fn_train(batch)
        else:
            return self.processor.collate_fn_inference(batch)


# =============================================================================
# Trainer
# =============================================================================

class ExtractorTrainer(Trainer):
    """
    Trainer for GLiNER2 with optimized preprocessing.

    Key features:
    - Parallel preprocessing via DataLoader workers
    - Separate learning rates for encoder and other layers
    - Optional classifier-only fine-tuning
    - FP16 support
    - Gradient accumulation

    Example:
        >>> processor = SchemaTransformer(model_name, sampling_config=config)
        >>> collator = ExtractorDataCollator(processor, is_training=True)
        >>>
        >>> trainer = ExtractorTrainer(
        ...     model=model,
        ...     args=TrainingArguments(
        ...         output_dir="./output",
        ...         per_device_train_batch_size=32,
        ...         dataloader_num_workers=8,  # Parallel preprocessing!
        ...         dataloader_pin_memory=True,
        ...     ),
        ...     train_dataset=dataset,
        ...     data_collator=collator,
        ...     encoder_lr=1e-5,
        ...     custom_lr=5e-4,
        ...     weight_decay=0.01,
        ... )
        >>> trainer.train()
    """

    def __init__(
        self,
        encoder_lr: float = 1e-5,
        custom_lr: float = 5e-4,
        weight_decay: float = 0.01,
        finetune_classifier: bool = False,
        **kwargs
    ):
        """
        Initialize trainer.

        Args:
            encoder_lr: Learning rate for encoder parameters
            custom_lr: Learning rate for non-encoder parameters
            weight_decay: Weight decay for all parameters
            finetune_classifier: If True, freeze all except classifier
            **kwargs: Arguments passed to Trainer
        """
        self.encoder_lr = encoder_lr
        self.custom_lr = custom_lr
        self.custom_weight_decay = weight_decay
        self.finetune_classifier = finetune_classifier

        super().__init__(**kwargs)

        if self.finetune_classifier:
            self._freeze_non_classifier()

    def _freeze_non_classifier(self):
        """Freeze all parameters except classifier."""
        print("Finetuning classifier only: freezing all other parameters.")
        for name, param in self.model.named_parameters():
            if not name.startswith("classifier"):
                param.requires_grad = False

    def create_optimizer(self):
        """Create optimizer with separate parameter groups (encoder vs. rest)."""
        if self.finetune_classifier:
            # Only classifier parameters
            classifier_params = [
                p for n, p in self.model.named_parameters()
                if n.startswith("classifier") and p.requires_grad
            ]
            if not classifier_params:
                raise ValueError("No trainable parameters in classifier.")

            groups = [{
                "params": classifier_params,
                "lr": self.custom_lr,
                "weight_decay": self.custom_weight_decay,
            }]
        else:
            # Separate encoder and other parameters.
            # NOTE(review): the substring test `"encoder" not in n` also
            # excludes any non-encoder parameter whose name merely contains
            # "encoder" — `not n.startswith("encoder.")` would be tighter.
            encoder_params = list(self.model.encoder.parameters())
            other_params = [
                p for n, p in self.model.named_parameters()
                if "encoder" not in n and p.requires_grad
            ]

            groups = [
                {
                    "params": encoder_params,
                    "lr": self.encoder_lr,
                    "weight_decay": self.custom_weight_decay
                },
                {
                    "params": other_params,
                    "lr": self.custom_lr,
                    "weight_decay": self.custom_weight_decay
                },
            ]

        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        self.optimizer = optimizer_cls(groups, **optimizer_kwargs)

    def compute_loss(self, model, inputs: PreprocessedBatch, return_outputs: bool = False, **kwargs):
        """
        Compute loss
on preprocessed batch.

        Args:
            model: The model
            inputs: PreprocessedBatch from collator
            return_outputs: Whether to return outputs dict

        Returns:
            Loss tensor, optionally with outputs dict
        """
        # Forward pass - inputs is already PreprocessedBatch
        outputs = model(inputs, return_individual_losses=False)

        # Handle empty batch: return a zero tensor with requires_grad so
        # backward() still works on an all-empty accumulation step.
        if outputs["batch_size"] == 0:
            device = next(model.parameters()).device
            loss = torch.tensor(0.0, device=device, requires_grad=True)
        else:
            loss = outputs["total_loss"]

        return (loss, outputs) if return_outputs else loss


# =============================================================================
# Training Utilities
# =============================================================================

def create_training_dataloader(
    dataset: ExtractorDataset,
    processor: SchemaTransformer,
    batch_size: int = 32,
    num_workers: int = 8,
    pin_memory: bool = True,
    shuffle: bool = True,
    prefetch_factor: int = 2,
) -> DataLoader:
    """
    Create an optimized DataLoader for training.

    This function creates a DataLoader configured for maximum preprocessing
    efficiency using parallel workers.

    Args:
        dataset: ExtractorDataset instance
        processor: SchemaTransformer for preprocessing
        batch_size: Batch size
        num_workers: Number of parallel workers for preprocessing
        pin_memory: Pin memory for faster GPU transfer
        shuffle: Shuffle data each epoch
        prefetch_factor: Batches to prefetch per worker

    Returns:
        Configured DataLoader

    Example:
        >>> loader = create_training_dataloader(
        ...     dataset=train_dataset,
        ...     processor=processor,
        ...     batch_size=32,
        ...     num_workers=8,
        ... )
        >>> for batch in loader:
        ...     batch = batch.to(device)
        ...     loss = model(batch)["total_loss"]
    """
    collator = ExtractorDataCollator(processor, is_training=True)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # prefetch_factor is only valid with worker processes.
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        collate_fn=collator,
        persistent_workers=num_workers > 0,
    )


def create_inference_dataloader(
    texts: List[str],
    schemas: List[dict],
    processor: SchemaTransformer,
    batch_size: int = 32,
    num_workers: int = 4,
) -> DataLoader:
    """
    Create a DataLoader for inference.

    Args:
        texts: List of input texts
        schemas: List of schemas (same length as texts or single schema)
        processor: SchemaTransformer for preprocessing
        batch_size: Batch size
        num_workers: Number of workers

    Returns:
        DataLoader yielding PreprocessedBatch
    """
    # Handle single schema for all texts.
    # NOTE(review): if 1 < len(schemas) != len(texts), zip() below silently
    # truncates to the shorter list — consider raising a ValueError instead.
    if len(schemas) == 1:
        schemas = schemas * len(texts)

    dataset = list(zip(texts, schemas))
    collator = ExtractorDataCollator(processor, is_training=False)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collator,
    )
\ No newline at end of file
diff --git a/packages/GLiNER2/gliner2/processor.py b/packages/GLiNER2/gliner2/processor.py
new file mode 100644
index 0000000..73956aa
--- /dev/null
+++ b/packages/GLiNER2/gliner2/processor.py
@@ -0,0 +1,1072 @@
"""
GLiNER2 Schema Transformer with Optimized Batch Processing

This module handles all preprocessing for GLiNER2, with efficient batching
via DataLoader collate functions for parallel preprocessing.
+""" + +import copy +import random +import re +from dataclasses import dataclass, field +from typing import Any, Dict, Tuple, Iterator, List +import torch +from transformers import AutoTokenizer + + +# ============================================================================= +# Data Structures +# ============================================================================= + +@dataclass +class TransformedRecord: + """Single transformed record ready for batching.""" + input_ids: List[int] + mapped_indices: List[Tuple[str, int, int]] + schema_tokens_list: List[List[str]] + text_tokens: List[str] + structure_labels: List[Any] + task_types: List[str] + start_token_idx: List[int] + end_token_idx: List[int] + text: str + schema: Dict[str, Any] + num_schemas: int = field(init=False) + + def __post_init__(self): + self.num_schemas = len(self.schema_tokens_list) + + +@dataclass +class PreprocessedBatch: + """GPU-ready batch for training/inference.""" + input_ids: torch.Tensor # (batch, max_seq_len) + attention_mask: torch.Tensor # (batch, max_seq_len) + mapped_indices: List[List[Tuple]] # Per-sample token mappings + schema_counts: List[int] # Number of schemas per sample + original_lengths: List[int] # Original sequence lengths + structure_labels: List[List[Any]] # Ground truth labels + task_types: List[List[str]] # Task types per schema + text_tokens: List[List[str]] # Original text tokens + schema_tokens_list: List[List[List[str]]] # Schema tokens per sample + start_mappings: List[List[int]] # Char position start mappings + end_mappings: List[List[int]] # Char position end mappings + original_texts: List[str] # For result formatting + original_schemas: List[Dict] # For result formatting + + def to(self, device: torch.device) -> 'PreprocessedBatch': + """Move tensors to device.""" + return PreprocessedBatch( + input_ids=self.input_ids.to(device), + attention_mask=self.attention_mask.to(device), + mapped_indices=self.mapped_indices, + schema_counts=self.schema_counts, + 
original_lengths=self.original_lengths, + structure_labels=self.structure_labels, + task_types=self.task_types, + text_tokens=self.text_tokens, + schema_tokens_list=self.schema_tokens_list, + start_mappings=self.start_mappings, + end_mappings=self.end_mappings, + original_texts=self.original_texts, + original_schemas=self.original_schemas, + ) + + def pin_memory(self) -> 'PreprocessedBatch': + """Pin tensors to memory for faster GPU transfer.""" + return PreprocessedBatch( + input_ids=self.input_ids.pin_memory(), + attention_mask=self.attention_mask.pin_memory(), + mapped_indices=self.mapped_indices, + schema_counts=self.schema_counts, + original_lengths=self.original_lengths, + structure_labels=self.structure_labels, + task_types=self.task_types, + text_tokens=self.text_tokens, + schema_tokens_list=self.schema_tokens_list, + start_mappings=self.start_mappings, + end_mappings=self.end_mappings, + original_texts=self.original_texts, + original_schemas=self.original_schemas, + ) + + def __contains__(self, key: str) -> bool: + """Check if key is a field name. Required for HuggingFace Trainer compatibility.""" + return hasattr(self, key) + + def __iter__(self): + """Iterate over field names. Required for HuggingFace Trainer compatibility.""" + return iter(self.__dataclass_fields__.keys()) + + def __getitem__(self, key): + """Get field by name. 
Required for HuggingFace Trainer compatibility.""" + if isinstance(key, str): + return getattr(self, key) + raise KeyError(f"PreprocessedBatch does not support integer indexing: {key}") + + def __len__(self) -> int: + return self.input_ids.shape[0] + + +# ============================================================================= +# Tokenizer +# ============================================================================= + +class WhitespaceTokenSplitter: + """Fast regex-based tokenizer for text splitting.""" + __slots__ = () + + _PATTERN = re.compile( + r"""(?:https?://[^\s]+|www\.[^\s]+) + |[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,} + |@[a-z0-9_]+ + |\w+(?:[-_]\w+)* + |\S""", + re.VERBOSE | re.IGNORECASE, + ) + + def __call__(self, text: str, lower: bool = True) -> Iterator[Tuple[str, int, int]]: + if lower: + text = text.lower() + for m in self._PATTERN.finditer(text): + yield m.group(), m.start(), m.end() + + +# ============================================================================= +# Sampling Configuration +# ============================================================================= + +@dataclass +class SamplingConfig: + """Configuration for stochastic sampling during training.""" + # JSON Structures + remove_json_structure_prob: float = 0.2 + shuffle_json_fields: bool = True + remove_json_field_prob: float = 0.2 + # Entities + remove_entities_prob: float = 0.0 + shuffle_entities: bool = False + remove_entity_prob: float = 0.0 + synthetic_entity_label_prob: float = 0.2 + # Relations + remove_relations_prob: float = 0.2 + swap_head_tail_prob: float = 0.2 + # Classifications + remove_classification_prob: float = 0.0 + shuffle_classification_labels: bool = True + remove_classification_label_prob: float = 0.5 + synthetic_label_prob: float = 0.5 + include_true_label_prob: float = 0.5 + max_num_labels: int = 1000 + + +# ============================================================================= +# Main Processor Class +# 
=============================================================================

class SchemaTransformer:
    """
    Schema-based text transformer for GLiNER2.

    Provides efficient batch preprocessing via collate functions
    for parallel DataLoader preprocessing.
    """

    # Special tokens marking schema structure in the input sequence
    SEP_STRUCT = "[SEP_STRUCT]"
    SEP_TEXT = "[SEP_TEXT]"
    P_TOKEN = "[P]"
    C_TOKEN = "[C]"
    E_TOKEN = "[E]"
    R_TOKEN = "[R]"
    L_TOKEN = "[L]"
    EXAMPLE_TOKEN = "[EXAMPLE]"
    OUTPUT_TOKEN = "[OUTPUT]"
    DESC_TOKEN = "[DESCRIPTION]"

    SPECIAL_TOKENS = [
        SEP_STRUCT, SEP_TEXT, P_TOKEN, C_TOKEN, E_TOKEN,
        R_TOKEN, L_TOKEN, EXAMPLE_TOKEN, OUTPUT_TOKEN, DESC_TOKEN
    ]

    def __init__(
        self,
        model_name: str = None,
        tokenizer=None,
        sampling_config: SamplingConfig = None,
        token_pooling: str = "first"
    ):
        if model_name is None and tokenizer is None:
            raise ValueError("Either model_name or tokenizer must be provided.")

        # NOTE(review): an invalid token_pooling value silently falls back to
        # "first" — consider raising instead of masking typos.
        self.token_pooling = token_pooling if token_pooling in ["first", "mean", "max"] else "first"
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_name)
        self.word_splitter = WhitespaceTokenSplitter()
        self.sampling_config = sampling_config or SamplingConfig()
        self.is_training = False

        # Add special tokens.
        # NOTE(review): this grows the vocab — the model's embedding matrix
        # presumably gets resized elsewhere; confirm against model init.
        self.tokenizer.add_special_tokens({
            "additional_special_tokens": self.SPECIAL_TOKENS
        })

        # OPT-1: Pre-compute special token IDs for fast lookup in embedding extraction
        self._special_ids = frozenset(
            self.tokenizer.convert_tokens_to_ids(t)
            for t in (self.P_TOKEN, self.C_TOKEN, self.E_TOKEN, self.R_TOKEN, self.L_TOKEN)
        )

        # OPT-6: Cache tokenized forms of special tokens and common punctuation
        self._token_cache = {}
        for tok in self.SPECIAL_TOKENS + ["(", ")", ",", "|"]:
            self._token_cache[tok] = self.tokenizer.tokenize(tok)

    def change_mode(self, is_training: bool):
        """Switch between training and inference mode."""
        self.is_training = is_training

    #
========================================================================= + # Main Public API: Collate Functions + # ========================================================================= + + def collate_fn_train( + self, + batch: List[Tuple[str, Dict]] + ) -> PreprocessedBatch: + """ + Collate function for training DataLoader. + + Use this with DataLoader for parallel preprocessing: + + loader = DataLoader( + dataset, + batch_size=32, + collate_fn=processor.collate_fn_train, + num_workers=8 + ) + + Args: + batch: List of (text, schema) tuples from dataset + + Returns: + PreprocessedBatch ready for model.forward() + """ + self.is_training = True + return self._collate_batch(batch) + + def collate_fn_inference( + self, + batch: List[Tuple[str, Any]] + ) -> PreprocessedBatch: + """ + Collate function for inference DataLoader. + + Args: + batch: List of (text, schema) tuples + + Returns: + PreprocessedBatch for batch_extract + """ + self.is_training = False + return self._collate_batch(batch) + + def transform_and_format( + self, + text: str, + schema: Dict[str, Any] + ) -> TransformedRecord: + """ + Transform and format a single record. + + This is the main preprocessing entry point for single records. + For batch processing, use collate_fn_train/collate_fn_inference. 
+ + Args: + text: Input text + schema: Schema dictionary + + Returns: + TransformedRecord ready for batching + """ + record = {"text": text, "schema": copy.deepcopy(schema)} + return self._transform_record(record) + + # ========================================================================= + # Internal: Batch Processing + # ========================================================================= + + def _collate_batch( + self, + batch: List[Tuple[str, Any]] + ) -> PreprocessedBatch: + """Internal collate implementation.""" + transformed_records = [] + + for text, schema in batch: + # Handle Schema objects + if hasattr(schema, 'build'): + schema = schema.build() + elif hasattr(schema, 'schema'): + schema = schema.schema + + # Ensure text ends with punctuation + if text and not text.endswith(('.', '!', '?')): + text = text + "." + elif not text: + text = "." + + record = {"text": text, "schema": copy.deepcopy(schema)} + + try: + transformed = self._transform_record(record) + transformed_records.append(transformed) + except Exception as e: + # Create minimal fallback record + transformed_records.append(self._create_fallback_record(text, schema)) + + return self._pad_batch(transformed_records) + + def _transform_record(self, record: Dict[str, Any]) -> TransformedRecord: + """Transform a single record (internal).""" + # OPT-4: Caller (_collate_batch) already deepcopies the schema. + # Only deepcopy here for direct callers (transform_and_format). 
+ text, schema = record["text"], record["schema"] + + # Build classification prefix + prefix = self._build_classification_prefix(schema) + + # Save a copy of the original schema BEFORE wrapping modifies it + # This preserves choice field info for extraction + original_schema = copy.deepcopy(schema) + + # Handle classification field wrapping + if prefix: + self._wrap_classification_fields(schema, prefix) + + # Tokenize text + text_tokens = [] + start_idx_map = [] + end_idx_map = [] + for tkn, start, end in self.word_splitter(text, lower=True): + text_tokens.append(tkn) + start_idx_map.append(start) + end_idx_map.append(end) + + if prefix: + text_tokens = prefix + text_tokens + len_prefix = len(prefix) + + # Infer schema + processed = self._infer_from_json(schema) + + # Build outputs + results = self._build_outputs( + processed, schema, text_tokens, len_prefix + ) + + # Format input + schema_tokens_list = [r["schema_tokens"] for r in results] + format_result = self._format_input_with_mapping(schema_tokens_list, text_tokens) + + return TransformedRecord( + input_ids=format_result["input_ids"], + mapped_indices=format_result["mapped_indices"], + schema_tokens_list=schema_tokens_list, + text_tokens=text_tokens, + structure_labels=[r["output"] for r in results], + task_types=[r["task_type"] for r in results], + start_token_idx=start_idx_map, + end_token_idx=end_idx_map, + text=text, + schema=original_schema, # Use original schema with choice info preserved + ) + + def _pad_batch( + self, + records: List[TransformedRecord] + ) -> PreprocessedBatch: + """Pad transformed records into a batch.""" + if not records: + return self._empty_batch() + + max_len = max(len(r.input_ids) for r in records) + batch_size = len(records) + + # Pre-allocate tensors + input_ids = torch.zeros((batch_size, max_len), dtype=torch.long) + attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long) + original_lengths = [] + + for i, rec in enumerate(records): + seq_len = 
len(rec.input_ids) + input_ids[i, :seq_len] = torch.tensor(rec.input_ids, dtype=torch.long) + attention_mask[i, :seq_len] = 1 + original_lengths.append(seq_len) + + return PreprocessedBatch( + input_ids=input_ids, + attention_mask=attention_mask, + mapped_indices=[r.mapped_indices for r in records], + schema_counts=[r.num_schemas for r in records], + original_lengths=original_lengths, + structure_labels=[r.structure_labels for r in records], + task_types=[r.task_types for r in records], + text_tokens=[r.text_tokens for r in records], + schema_tokens_list=[r.schema_tokens_list for r in records], + start_mappings=[r.start_token_idx for r in records], + end_mappings=[r.end_token_idx for r in records], + original_texts=[r.text for r in records], + original_schemas=[r.schema for r in records], + ) + + def _empty_batch(self) -> PreprocessedBatch: + """Create empty batch for edge cases.""" + return PreprocessedBatch( + input_ids=torch.zeros((0, 0), dtype=torch.long), + attention_mask=torch.zeros((0, 0), dtype=torch.long), + mapped_indices=[], + schema_counts=[], + original_lengths=[], + structure_labels=[], + task_types=[], + text_tokens=[], + schema_tokens_list=[], + start_mappings=[], + end_mappings=[], + original_texts=[], + original_schemas=[], + ) + + def _create_fallback_record(self, text: str, schema: Dict) -> TransformedRecord: + """Create minimal valid record for failed transformations.""" + dummy_tokens = [ + "(", "[P]", "dummy", "(", "[E]", "entity", ")", ")" + ] + format_result = self._format_input_with_mapping([dummy_tokens], ["."]) + + return TransformedRecord( + input_ids=format_result["input_ids"], + mapped_indices=format_result["mapped_indices"], + schema_tokens_list=[dummy_tokens], + text_tokens=["."], + structure_labels=[[1, [[(0, 0)]]]], + task_types=["entities"], + start_token_idx=[0], + end_token_idx=[1], + text=text or ".", + schema=schema or {}, + ) + + # ========================================================================= + # Internal: Schema 
Processing + # ========================================================================= + + def _build_classification_prefix(self, schema: Dict[str, Any]) -> List[str]: + """Build classification prefix tokens.""" + prefix_tokens = [] + + for struct in schema.get("json_structures", []): + for parent, fields in struct.items(): + cls_fields = [ + (fname, fval) for fname, fval in fields.items() + if isinstance(fval, dict) and "value" in fval and "choices" in fval + ] + + if self.is_training: + random.shuffle(cls_fields) + + inner = [] + for fname, fval in cls_fields: + choices = fval["choices"].copy() + if self.is_training: + random.shuffle(choices) + + choice_tokens = [] + for i, c in enumerate(choices): + if i > 0: + choice_tokens.append('|') + choice_tokens.append(c) + + inner.extend([fname, '('] + choice_tokens + [')', ',']) + + if inner: + inner = inner[:-1] + prefix_tokens.extend(['(', f"{parent}:", *inner, ')']) + + return prefix_tokens + + def _wrap_classification_fields(self, schema: Dict, prefix: List[str]): + """Wrap classification field values with [selection] prefix.""" + + def wrap(val): + if isinstance(val, list): + return [f"[selection]{v}" for v in val] + return f"[selection]{val}" + + cls_keys = { + f"{parent}.{fname}" + for struct in schema.get("json_structures", []) + for parent, fields in struct.items() + for fname, fval in fields.items() + if isinstance(fval, dict) and {"value", "choices"} <= fval.keys() + } + + for struct in schema.get("json_structures", []): + for parent, fields in struct.items(): + for fname in list(fields): + key = f"{parent}.{fname}" + if key not in cls_keys: + continue + fval = fields[fname] + raw = fval["value"] if isinstance(fval, dict) else fval + fields[fname] = wrap(raw) + + def _infer_from_json(self, schema: Dict[str, Any]) -> Dict[str, Any]: + """Infer schemas and labels from JSON schema.""" + schemas = [] + labels = [] + types = [] + + sampling = self.sampling_config if self.is_training else None + + # Process JSON 
structures + self._process_json_structures(schema, schemas, labels, types, sampling) + + # Process entities + self._process_entities(schema, schemas, labels, types, sampling) + + # Process relations + self._process_relations(schema, schemas, labels, types, sampling) + + # Process classifications + self._process_classifications(schema, schemas, labels, types, sampling) + + # Shuffle task order during training + if sampling: + order = list(range(len(types))) + random.shuffle(order) + schemas = [schemas[i] for i in order] + labels = [labels[i] for i in order] + types = [types[i] for i in order] + + return { + "schemas": schemas, + "structure_labels": labels, + "task_types": types, + "new_schema": schema + } + + def _process_json_structures(self, schema, schemas, labels, types, sampling): + """Process JSON structure schemas.""" + if "json_structures" not in schema: + return + + json_descs = schema.get("json_descriptions", {}) + groups = {} + + for item in schema["json_structures"]: + for parent, fields in item.items(): + groups.setdefault(parent, []).append(fields) + + for parent, occurrences in groups.items(): + if sampling and random.random() < sampling.remove_json_structure_prob: + continue + + all_fields = set() + for occ in occurrences: + all_fields.update(occ.keys()) + common = list(all_fields) + + if sampling and sampling.shuffle_json_fields: + random.shuffle(common) + + chosen = [f for f in common if not ( + sampling and random.random() < sampling.remove_json_field_prob + )] + if not chosen: + continue + + # Handle synthetic labeling + real2syn = {} + descs = json_descs.get(parent, {}) + example_modes = ["none", "descriptions"] + + if sampling and random.random() < sampling.synthetic_entity_label_prob: + example_modes.remove("none") + synthetic = [] + for i, real in enumerate(chosen, 1): + syn = f"field {i}" + real2syn[real] = syn + synthetic.append(syn) + descs = {real2syn.get(k, k): descs.get(k, k) for k in chosen} + chosen = synthetic + + # Build spans + 
spans = [] + for occ in occurrences: + span = [occ.get(f) for f in chosen] + spans.append(span) + + # Dedup + uniq = [] + seen = set() + for s in spans: + key = tuple(tuple(x) if isinstance(x, list) else x for x in s) + if key not in seen: + uniq.append(s) + seen.add(key) + + # Check for empty + if all(all(c is None or c == "" for c in span) for span in uniq): + count = 0 + uniq = [] + else: + count = len(uniq) + + labels.append([count, uniq]) + + mode = random.choice(example_modes) if self.is_training else ( + "descriptions" if descs else "none" + ) + + schemas.append(self._transform_schema( + parent, chosen, self.C_TOKEN, label_descriptions=descs, example_mode=mode + )) + types.append("json_structures") + + def _process_entities(self, schema, schemas, labels, types, sampling): + """Process entity schemas.""" + if "entities" not in schema: + return + + if sampling and random.random() < sampling.remove_entities_prob: + return + + entity_fields = list(schema["entities"].keys()) + descs = schema.get("entity_descriptions", {}) + example_modes = ["none", "descriptions"] + + real2syn = {} + if sampling and random.random() < sampling.synthetic_entity_label_prob: + example_modes.remove("none") + synthetic = [] + for i, real in enumerate(entity_fields, 1): + syn = f"entity {i}" + real2syn[real] = syn + synthetic.append(syn) + descs = {real2syn.get(k, k): v for k, v in descs.items()} + schema["entities"] = {real2syn.get(k, k): v for k, v in schema["entities"].items()} + entity_fields = synthetic + + if sampling and sampling.shuffle_entities: + random.shuffle(entity_fields) + + chosen = [e for e in entity_fields if not ( + sampling and random.random() < sampling.remove_entity_prob + )] + + if chosen: + span = [schema["entities"][e] for e in chosen] + labels.append([1, [span]]) + + mode = random.choice(example_modes) if self.is_training else ( + "descriptions" if descs else "none" + ) + + schemas.append(self._transform_schema( + "entities", chosen, self.E_TOKEN, 
label_descriptions=descs, example_mode=mode + )) + types.append("entities") + + def _process_relations(self, schema, schemas, labels, types, sampling): + """Process relation schemas.""" + if "relations" not in schema: + return + + groups = {} + for item in schema["relations"]: + if sampling and random.random() < sampling.remove_relations_prob: + continue + for parent, fields in item.items(): + groups.setdefault(parent, []).append(fields) + + for parent, occurrences in groups.items(): + field_names = list(occurrences[0].keys()) + + if sampling and "head" in field_names and "tail" in field_names: + if random.random() < sampling.swap_head_tail_prob: + idx_h = field_names.index("head") + idx_t = field_names.index("tail") + field_names[idx_h], field_names[idx_t] = field_names[idx_t], field_names[idx_h] + + spans = [] + for occ in occurrences: + if all(f in occ for f in field_names): + spans.append([occ[f] for f in field_names]) + + if not spans: + continue + + # Dedup + seen = set() + uniq = [] + for span in spans: + t = tuple(tuple(s) if isinstance(s, list) else s for s in span) + if t not in seen: + seen.add(t) + uniq.append(span) + + labels.append([len(uniq), uniq]) + schemas.append(self._transform_schema(parent, field_names, self.R_TOKEN)) + types.append("relations") + + def _process_classifications(self, schema, schemas, labels, types, sampling): + """Process classification schemas.""" + if "classifications" not in schema: + return + + for idx, item in enumerate(schema["classifications"]): + if sampling and random.random() < sampling.remove_classification_prob: + continue + + cls_labels = item["labels"].copy() + examples = item.get("examples", []) + descs = item.get("label_descriptions", {}) or {} + + real2syn = {} + example_modes = ["few_shot", "descriptions", "both", "none"] if self.is_training else ["both"] + + if sampling and random.random() < sampling.synthetic_label_prob: + example_modes = [m for m in example_modes if m != "none"] + synthetic = [] + for i, 
real in enumerate(cls_labels, 1): + syn = f"label {i}" + real2syn[real] = syn + synthetic.append(syn) + cls_labels = synthetic + descs = {real2syn.get(k, k): descs.get(k, k) for k in item["labels"]} + examples = [(inp, real2syn.get(out, out)) for inp, out in examples] + + mode = random.choice(example_modes) if example_modes else "none" + + # Label dropping + if sampling and hasattr(sampling, "remove_classification_label_prob"): + drop_frac = random.betavariate(1, 1) * sampling.remove_classification_label_prob + num_remove = int(len(cls_labels) * drop_frac) + if num_remove > 0: + cls_labels = random.sample(cls_labels, len(cls_labels) - num_remove) + + max_labels = sampling.max_num_labels // 2 if mode in ["few_shot", "both", + "descriptions"] else sampling.max_num_labels + if len(cls_labels) > max_labels: + cls_labels = cls_labels[:max_labels] + + if random.random() < sampling.include_true_label_prob: + true_label = item.get("true_label", []) + if isinstance(true_label, list): + for tl in true_label: + if tl not in cls_labels: + cls_labels.append(tl) + elif true_label not in cls_labels: + cls_labels.append(true_label) + + if sampling and sampling.shuffle_classification_labels: + random.shuffle(cls_labels) + + schemas.append(self._transform_schema( + item["task"], cls_labels, self.L_TOKEN, + prompt=item.get("prompt"), examples=examples, + label_descriptions=descs, example_mode=mode + )) + types.append("classifications") + + # Update schema + schema["classifications"][idx]["labels"] = cls_labels + true_label = schema["classifications"][idx]["true_label"].copy() + schema["classifications"][idx]["true_label"] = [real2syn.get(i, i) for i in true_label] + labels.append([]) + + def _transform_schema( + self, + parent: str, + fields: List[str], + child_prefix: str, + prompt: str = None, + examples: List[Tuple[str, str]] = None, + label_descriptions: Dict[str, str] = None, + example_mode: str = "both" + ) -> List[str]: + """Transform schema into token sequence.""" + 
prompt_str = parent + if prompt: + prompt_str = f"{parent}: {prompt}" + + if example_mode in ["descriptions", "both"] and label_descriptions: + descs = [(l, d) for l, d in label_descriptions.items() if l in fields] + if self.is_training: + random.shuffle(descs) + for label, desc in descs: + prompt_str += f" {self.DESC_TOKEN} {label}: {desc}" + + if example_mode in ["few_shot", "both"] and examples: + if self.is_training: + random.shuffle(examples) + for inp, out in examples: + if out in fields: + out_str = out if isinstance(out, str) else ', '.join(out) + prompt_str += f" {self.EXAMPLE_TOKEN} {inp} {self.OUTPUT_TOKEN} {out_str}" + + tokens = ["(", self.P_TOKEN, prompt_str, "("] + for field in fields: + tokens.extend([child_prefix, field]) + tokens.extend([")", ")"]) + + return tokens + + def _build_outputs( + self, + processed: Dict, + schema: Dict, + text_tokens: List[str], + len_prefix: int + ) -> List[Dict]: + """Build output labels for each schema.""" + results = [] + + for schema_tokens, task_type, struct_label in zip( + processed["schemas"], + processed["task_types"], + processed["structure_labels"] + ): + if task_type != "classifications": + count, spans = struct_label + transformed = [] + + for span in spans: + positions = [] + for field in span: + if isinstance(field, list): + nested = [] + for sub in field: + if str(sub).startswith("[selection]"): + # Use case-insensitive matching for choice fields + pos = self._find_sublist( + [str(sub)[11:]], text_tokens[:len_prefix], + case_insensitive=True + ) + else: + pos = self._find_sublist( + self._tokenize_text(str(sub)), text_tokens + ) + nested.extend(pos) + positions.append(nested) + else: + if str(field).startswith("[selection]"): + # Use case-insensitive matching for choice fields + pos = self._find_sublist( + [str(field)[11:]], text_tokens[:len_prefix], + case_insensitive=True + ) + else: + pos = self._find_sublist( + self._tokenize_text(str(field)), text_tokens + ) + positions.append(pos) + 
transformed.append(positions) + + results.append({ + "task_type": task_type, + "schema_tokens": schema_tokens, + "output": [count, transformed] + }) + else: + cls_item = next( + (c for c in schema["classifications"] if schema_tokens[2].startswith(c["task"])), + None + ) + if cls_item is None: + raise ValueError(f"Missing classification for: {schema_tokens[2]}") + + bool_labels = [1 if l in cls_item["true_label"] else 0 for l in cls_item["labels"]] + results.append({ + "task_type": task_type, + "schema_tokens": schema_tokens, + "output": bool_labels + }) + + return results + + def _find_sublist( + self, + sub: List[str], + lst: List[str], + case_insensitive: bool = False + ) -> List[Tuple[int, int]]: + """Find all occurrences of sublist in list. + + Args: + sub: Sublist to search for + lst: List to search in + case_insensitive: If True, use case-insensitive matching + """ + if not sub or all(t == "" for t in sub): + return [(-1, -1)] + + sub_len = len(sub) + + if case_insensitive: + sub_lower = [s.lower() for s in sub] + matches = [ + (i, i + sub_len - 1) + for i in range(len(lst) - sub_len + 1) + if [t.lower() for t in lst[i:i + sub_len]] == sub_lower + ] + else: + matches = [ + (i, i + sub_len - 1) + for i in range(len(lst) - sub_len + 1) + if lst[i:i + sub_len] == sub + ] + return matches or [(-1, -1)] + + def _tokenize_text(self, text: str) -> List[str]: + """Tokenize text into words.""" + return [tok for tok, _, _ in self.word_splitter(text, lower=True)] + + # ========================================================================= + # Internal: Input Formatting + # ========================================================================= + + def _format_input_with_mapping( + self, + schema_tokens_list: List[List[str]], + text_tokens: List[str] + ) -> Dict[str, Any]: + """Format input and create token mappings.""" + # Build combined tokens + combined = [] + for struct in schema_tokens_list: + combined.extend(struct) + combined.append(self.SEP_STRUCT) + if 
combined: + combined.pop() + combined.append(self.SEP_TEXT) + combined.extend(text_tokens) + + # Build subword list and mappings + subwords = [] + mappings = [] + + num_schemas = len(schema_tokens_list) + text_schema_idx = num_schemas + current_schema = 0 + found_sep = False + + for orig_idx, token in enumerate(combined): + if token == self.SEP_TEXT: + seg_type = "sep" + schema_idx = text_schema_idx + found_sep = True + elif not found_sep: + seg_type = "schema" + schema_idx = current_schema + if token == self.SEP_STRUCT: + current_schema += 1 + else: + seg_type = "text" + schema_idx = text_schema_idx + + # OPT-6: Use cached tokenizations for special tokens and punctuation + if token in self._token_cache: + sub_tokens = self._token_cache[token] + else: + sub_tokens = self.tokenizer.tokenize(token) + subwords.extend(sub_tokens) + mappings.extend([(seg_type, orig_idx, schema_idx)] * len(sub_tokens)) + + input_ids = self.tokenizer.convert_tokens_to_ids(subwords) + + return { + "input_ids": input_ids, + "mapped_indices": mappings, + "subword_list": subwords + } + + # ========================================================================= + # Embedding Extraction (Called by Model) + # ========================================================================= + + def extract_embeddings_from_batch( + self, + token_embeddings: torch.Tensor, + input_ids: torch.Tensor, + batch: PreprocessedBatch + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + """ + Extract token and schema embeddings from encoded batch. 
+ + Args: + token_embeddings: (batch, seq_len, hidden) from encoder + input_ids: (batch, seq_len) input token IDs + batch: PreprocessedBatch with metadata + + Returns: + - all_token_embs: List of (text_len, hidden) per sample + - all_schema_embs: List of schema embeddings per sample + """ + all_token_embs = [] + all_schema_embs = [] + + # OPT-1: Use pre-computed special token IDs instead of string comparison + special_ids = self._special_ids + + for i in range(len(batch)): + seq_len = batch.original_lengths[i] + embs = token_embeddings[i, :seq_len, :] + ids = input_ids[i, :seq_len].tolist() + mappings = batch.mapped_indices[i][:seq_len] + num_schemas = batch.schema_counts[i] + + schema_embs = [[] for _ in range(num_schemas)] + word_embs = [] + bucket = [] + last_orig = None + + for j, tid in enumerate(ids): + seg_type, orig_idx, schema_idx = mappings[j] + emb = embs[j] + + if seg_type == "schema": + if tid in special_ids: + schema_embs[schema_idx].append(emb) + + elif seg_type == "text": + if last_orig is not None and orig_idx != last_orig and bucket: + word_embs.append(self._aggregate(bucket)) + bucket = [] + bucket.append(emb) + last_orig = orig_idx + + if bucket: + word_embs.append(self._aggregate(bucket)) + + all_token_embs.append( + torch.stack(word_embs) if word_embs else torch.empty(0, embs.shape[-1], device=embs.device) + ) + all_schema_embs.append(schema_embs) + + return all_token_embs, all_schema_embs + + def _aggregate(self, pieces: List[torch.Tensor]) -> torch.Tensor: + """Aggregate subword embeddings.""" + # OPT-10: Short-circuit for single subword tokens (common case) + if len(pieces) == 1: + return pieces[0] + if self.token_pooling == "first": + return pieces[0] + stack = torch.stack(pieces) + if self.token_pooling == "mean": + return stack.mean(dim=0) + if self.token_pooling == "max": + return stack.max(dim=0).values + return pieces[0] \ No newline at end of file diff --git a/packages/GLiNER2/gliner2/training/__init__.py 
b/packages/GLiNER2/gliner2/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/GLiNER2/gliner2/training/data.py b/packages/GLiNER2/gliner2/training/data.py new file mode 100644 index 0000000..24e23d2 --- /dev/null +++ b/packages/GLiNER2/gliner2/training/data.py @@ -0,0 +1,1277 @@ +""" +GLiNER2 Training Data Creation & Validation Module + +This module provides intuitive classes for creating, validating, and managing +training data for GLiNER2 models. + +Quick Examples +-------------- +Create entity examples: + >>> example = InputExample( + ... text="John works at Google in NYC.", + ... entities={"person": ["John"], "company": ["Google"], "location": ["NYC"]} + ... ) + +Create classification examples: + >>> example = InputExample( + ... text="This movie is amazing!", + ... classifications=[ + ... Classification(task="sentiment", labels=["positive", "negative"], true_label="positive") + ... ] + ... ) + +Create structured data examples: + >>> example = InputExample( + ... text="iPhone 15 costs $999", + ... structures=[ + ... Structure("product", name="iPhone 15", price="$999") + ... ] + ... ) + +Create relation examples: + >>> example = InputExample( + ... text="Elon Musk founded SpaceX.", + ... relations=[ + ... Relation("founded", head="Elon Musk", tail="SpaceX") + ... ] + ... 
class ValidationError(Exception):
    """Raised when training data validation fails."""

    def __init__(self, message: str, errors: List[str] = None):
        super().__init__(message)
        # Keep individual error messages available for structured reporting.
        self.errors = errors or []

    def __str__(self):
        if not self.errors:
            return self.args[0]
        shown = "\n - ".join(self.errors[:10])
        remaining = len(self.errors) - 10
        suffix = f"\n ... and {remaining} more errors" if remaining > 0 else ""
        return f"{self.args[0]}\n - {shown}{suffix}"


# =============================================================================
# Data Format Detection & Loading
# =============================================================================

class DataFormat:
    """Enum-like namespace of the supported training-data input formats."""
    JSONL = "jsonl"
    JSONL_LIST = "jsonl_list"
    INPUT_EXAMPLE_LIST = "input_example_list"
    TRAINING_DATASET = "training_dataset"
    DICT_LIST = "dict_list"
    EXTRACTOR_DATASET = "extractor_dataset"


def detect_data_format(data: Any) -> str:
    """
    Detect the format of input data.

    Parameters
    ----------
    data : Any
        Input data in any supported format.

    Returns
    -------
    str
        The detected format name from DataFormat.
    """
    # A single path (string or pathlib.Path) means one JSONL file.
    if isinstance(data, (str, Path)):
        return DataFormat.JSONL

    if isinstance(data, list):
        if not data:
            # Empty list: default to the dict-list format.
            return DataFormat.DICT_LIST
        head = data[0]
        if isinstance(head, (str, Path)):
            return DataFormat.JSONL_LIST
        if isinstance(head, InputExample):
            return DataFormat.INPUT_EXAMPLE_LIST
        if isinstance(head, dict):
            return DataFormat.DICT_LIST
        # Unrecognized element type: fall through to the object checks.

    if isinstance(data, TrainingDataset):
        return DataFormat.TRAINING_DATASET

    # Internal legacy dataset, matched by name to avoid an import cycle.
    if type(data).__name__ == 'ExtractorDataset':
        return DataFormat.EXTRACTOR_DATASET

    raise ValueError(f"Unsupported data format: {type(data)}")
+ """ + # String path + if isinstance(data, str): + return DataFormat.JSONL + + # Path object + if isinstance(data, Path): + return DataFormat.JSONL + + # List types + if isinstance(data, list) and len(data) > 0: + first = data[0] + if isinstance(first, (str, Path)): + return DataFormat.JSONL_LIST + if isinstance(first, InputExample): + return DataFormat.INPUT_EXAMPLE_LIST + if isinstance(first, dict): + return DataFormat.DICT_LIST + + # Empty list - default to dict list + if isinstance(data, list) and len(data) == 0: + return DataFormat.DICT_LIST + + # TrainingDataset + if isinstance(data, TrainingDataset): + return DataFormat.TRAINING_DATASET + + # ExtractorDataset (internal) - forward reference + if type(data).__name__ == 'ExtractorDataset': + return DataFormat.EXTRACTOR_DATASET + + raise ValueError(f"Unsupported data format: {type(data)}") + + +class DataLoader_Factory: + """ + Factory for loading data from various formats into a unified internal format. + + All loaders convert data to List[Dict] format where each dict has: + - "input": str (the text) + - "output": Dict (the schema/annotations) + + Or alternatively: + - "text": str + - "schema": Dict + """ + + @staticmethod + def load( + data: Any, + max_samples: int = -1, + shuffle: bool = True, + seed: int = 42, + validate: bool = False, + ) -> List[Dict[str, Any]]: + """ + Load data from any supported format. + + Parameters + ---------- + data : Any + Input data in any supported format. + max_samples : int, default=-1 + Maximum samples to load (-1 = all). + shuffle : bool, default=True + Whether to shuffle the data. + seed : int, default=42 + Random seed for shuffling. + validate : bool, default=False + Whether to validate the data. Validation is always strict: + checks that entity spans, relation values, and structure + field values exist in the text. + + Returns + ------- + List[Dict[str, Any]] + List of records in unified format. 
+ """ + fmt = detect_data_format(data) + + # Load based on format + if fmt == DataFormat.JSONL: + records = DataLoader_Factory._load_jsonl(data) + elif fmt == DataFormat.JSONL_LIST: + records = DataLoader_Factory._load_jsonl_list(data) + elif fmt == DataFormat.INPUT_EXAMPLE_LIST: + records = DataLoader_Factory._load_input_examples(data) + elif fmt == DataFormat.TRAINING_DATASET: + records = DataLoader_Factory._load_training_dataset(data) + elif fmt == DataFormat.DICT_LIST: + records = DataLoader_Factory._load_dict_list(data) + elif fmt == DataFormat.EXTRACTOR_DATASET: + records = data.data.copy() + else: + raise ValueError(f"Unsupported data format: {type(data)}") + + # Validate if requested + if validate and records: + valid_indices, invalid_info = DataLoader_Factory._validate_records(records) + + if invalid_info: + total_records = len(records) + num_invalid = len(invalid_info) + num_valid = len(valid_indices) + + print(f"\nValidation: Found {num_invalid} invalid record(s) out of {total_records} total") + print("Removed invalid records:") + + # Print first 5 invalid records + for idx, (record_idx, record, errors) in enumerate(invalid_info[:5]): + # Print first error for this record + error_msg = errors[0] if errors else "Unknown error" + print(f" Record {record_idx}: {error_msg}") + + if num_invalid > 5: + print(f" ... 
and {num_invalid - 5} more invalid record(s)") + + print(f"Kept {num_valid} valid record(s)\n") + + # Filter records to keep only valid ones + records = [records[i] for i in valid_indices] + + # Shuffle + if shuffle and records: + random.seed(seed) + random.shuffle(records) + + # Limit samples + if max_samples > 0 and len(records) > max_samples: + records = records[:max_samples] + + return records + + @staticmethod + def _load_jsonl(path: Union[str, Path]) -> List[Dict]: + """Load from single JSONL file.""" + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + records = [] + with open(path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if line: + try: + records.append(json.loads(line)) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in {path} line {line_num}: {e}") + + return records + + @staticmethod + def _load_jsonl_list(paths: List[Union[str, Path]]) -> List[Dict]: + """Load from multiple JSONL files.""" + records = [] + for path in paths: + records.extend(DataLoader_Factory._load_jsonl(path)) + return records + + @staticmethod + def _load_input_examples(examples: List[InputExample]) -> List[Dict]: + """Load from list of InputExample objects.""" + return [ex.to_dict() for ex in examples] + + @staticmethod + def _load_training_dataset(dataset: TrainingDataset) -> List[Dict]: + """Load from TrainingDataset object.""" + return dataset.to_records() + + @staticmethod + def _load_dict_list(dicts: List[Dict]) -> List[Dict]: + """Load from list of dicts.""" + if not dicts: + return [] + + first = dicts[0] + + # Check format + if "input" in first and "output" in first: + # Already in correct format + return dicts + elif "text" in first and "schema" in first: + # Alternative format - keep as is (handled in __getitem__) + return dicts + elif "text" in first: + # Maybe has entities/classifications at top level - try to convert + records = [] + for d in 
dicts: + output = {} + if "entities" in d: + output["entities"] = d["entities"] + if "classifications" in d: + output["classifications"] = d["classifications"] + if "relations" in d: + output["relations"] = d["relations"] + if "json_structures" in d: + output["json_structures"] = d["json_structures"] + records.append({"input": d["text"], "output": output}) + return records + else: + raise ValueError( + f"Unknown dict format. Expected keys like 'input'/'output', 'text'/'schema', " + f"or 'text' with annotation keys. Got: {list(first.keys())}" + ) + + @staticmethod + def _validate_records(records: List[Dict]) -> Tuple[List[int], List[Tuple[int, Dict, List[str]]]]: + """ + Validate and sanitize all records, removing only invalid parts. + + Uses granular sanitization: + - Entities: Drop entity type if any mention not found + - Classifications: Drop individual invalid classifications + - Structures: Remove invalid fields, drop if all invalid + - Relations: Drop relation if any field invalid + - Record: Drop only if no valid tasks remain + + Returns + ------- + Tuple[List[int], List[Tuple[int, Dict, List[str]]]] + - First element: List of valid record indices (sanitized records replace originals) + - Second element: List of (index, original_record, warning_messages) for dropped records + """ + valid_indices = [] + invalid_info = [] + + for i, record in tqdm(enumerate(records), total=len(records), desc="Validating records", unit="record"): + warnings = [] + try: + example = InputExample.from_dict(record) + sanitize_warnings, is_valid = example.sanitize() + + if is_valid: + # Replace record with sanitized version + records[i] = example.to_dict() + valid_indices.append(i) + if sanitize_warnings: + # Record was sanitized but still valid + warnings.extend(sanitize_warnings) + else: + # No valid content remains + warnings.extend(sanitize_warnings) + invalid_info.append((i, record, warnings)) + except Exception as e: + warnings.append(f"Failed to parse - {e}") + 
# Type alias for flexible data input
TrainDataInput = Union[
    str,  # Single JSONL path
    Path,  # Single JSONL path
    List[str],  # Multiple JSONL paths
    List[Path],  # Multiple JSONL paths
    List[Dict[str, Any]],  # Raw records
    'TrainingDataset',  # TrainingDataset (forward reference)
    'List[InputExample]',  # List of InputExample (forward reference)
    'ExtractorDataset',  # Legacy dataset (forward reference)
]


# =============================================================================
# Training Data Classes
# =============================================================================

@dataclass
class Classification:
    """
    A classification task definition.

    Parameters
    ----------
    task : str
        Name of the classification task (e.g., "sentiment", "category").
    labels : List[str]
        All possible labels for this task.
    true_label : str or List[str]
        The correct label(s) for this example.
    multi_label : bool, default=False
        Whether multiple labels can be selected.
    prompt : str, optional
        Custom prompt for the task.
    examples : List[Tuple[str, str]], optional
        Few-shot examples as (input, output) pairs.
    label_descriptions : Dict[str, str], optional
        Descriptions for each label.
    """
    task: str
    labels: List[str]
    true_label: Union[str, List[str]]
    multi_label: bool = False
    prompt: Optional[str] = None
    examples: Optional[List[Tuple[str, str]]] = None
    label_descriptions: Optional[Dict[str, str]] = None

    def __post_init__(self):
        # Normalize true_label into an internal list form.
        gold = [self.true_label] if isinstance(self.true_label, str) else list(self.true_label)
        self._true_label_list = gold
        # More than one gold label implies a multi-label task.
        if len(gold) > 1:
            self.multi_label = True

    def validate(self) -> List[str]:
        """Validate this classification and return list of errors."""
        errors: List[str] = []
        if not self.task:
            errors.append("Classification task name cannot be empty")
        if not self.labels:
            errors.append(f"Classification '{self.task}' has no labels")
        errors.extend(
            f"True label '{lbl}' not in labels list for task '{self.task}'"
            for lbl in self._true_label_list if lbl not in self.labels
        )
        if len(self._true_label_list) > 1 and not self.multi_label:
            errors.append(f"Multiple true labels provided for '{self.task}' but multi_label=False")
        for key in (self.label_descriptions or {}):
            if key not in self.labels:
                errors.append(f"Label description key '{key}' not in labels for task '{self.task}'")
        for i, ex in enumerate(self.examples or []):
            if not isinstance(ex, (list, tuple)) or len(ex) != 2:
                errors.append(f"Example {i} for task '{self.task}' must be (input, output) pair")
        return errors

    def to_dict(self) -> Dict[str, Any]:
        """Convert to training format dictionary."""
        payload: Dict[str, Any] = {
            "task": self.task,
            "labels": self.labels,
            "true_label": self._true_label_list,
        }
        if self.multi_label:
            payload["multi_label"] = True
        if self.prompt:
            payload["prompt"] = self.prompt
        if self.examples:
            payload["examples"] = [list(ex) for ex in self.examples]
        if self.label_descriptions:
            payload["label_descriptions"] = self.label_descriptions
        return payload
@dataclass
class ChoiceField:
    """
    A field with predefined choices (classification within structure).

    Parameters
    ----------
    value : str
        The selected value.
    choices : List[str]
        All possible choices.
    """
    value: str
    choices: List[str]

    def validate(self, field_name: str) -> List[str]:
        """Return an error when the selected value is not a listed choice."""
        errors = []
        if self.value not in self.choices:
            errors.append(f"Choice value '{self.value}' not in choices {self.choices} for field '{field_name}'")
        return errors

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as the {"value", "choices"} training format."""
        return {"value": self.value, "choices": self.choices}


@dataclass
class Structure:
    """
    A structured data extraction definition.

    Parameters
    ----------
    struct_name : str
        Name of the structure (e.g., "product", "contact").
    **fields : Any
        Field names and values. Values can be:
        - str: Single string value
        - List[str]: Multiple values
        - ChoiceField: Classification-style field with choices

    Examples
    --------
    >>> struct = Structure("product", name="iPhone", price="$999")
    >>> struct = Structure("contact", name="John", email="john@example.com")
    """
    struct_name: str
    _fields: Dict[str, Any] = field(default_factory=dict)
    descriptions: Optional[Dict[str, str]] = None

    # Custom __init__ (dataclass skips generating one when it is defined):
    # keyword arguments become the structure's fields.
    def __init__(self, struct_name: str, _descriptions: Dict[str, str] = None, **fields):
        self.struct_name = struct_name
        self._fields = fields
        self.descriptions = _descriptions

    def validate(self, text: str) -> List[str]:
        """
        Validate this structure.

        Parameters
        ----------
        text : str
            The text to validate against. Field values must exist in this text.

        Returns
        -------
        List[str]
            List of validation errors.
        """
        errors = []
        if not self.struct_name:
            errors.append("Structure name cannot be empty")
        if not self._fields:
            errors.append(f"Structure '{self.struct_name}' has no fields")
        # PERF: lower the text once instead of per field value.
        text_lower = text.lower()
        for field_name, value in self._fields.items():
            if isinstance(value, ChoiceField):
                errors.extend(value.validate(f"{self.struct_name}.{field_name}"))
            elif isinstance(value, list):
                for i, v in enumerate(value):
                    # FIX: guard against non-string list elements; previously
                    # ``v.lower()`` raised AttributeError on them (the scalar
                    # branch below already checks isinstance(value, str)).
                    if isinstance(v, str) and v and v.lower() not in text_lower:
                        errors.append(f"List value '{v}' at index {i} in '{self.struct_name}.{field_name}' not found in text")
            elif isinstance(value, str):
                if value and value.lower() not in text_lower:
                    errors.append(f"Value '{value}' for '{self.struct_name}.{field_name}' not found in text")
        return errors

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        """Serialize as {struct_name: {field: value_or_choice_dict}}."""
        fields_dict = {}
        for field_name, value in self._fields.items():
            if isinstance(value, ChoiceField):
                fields_dict[field_name] = value.to_dict()
            else:
                fields_dict[field_name] = value
        return {self.struct_name: fields_dict}

    def get_descriptions(self) -> Optional[Dict[str, Dict[str, str]]]:
        """Return {struct_name: descriptions} or None when no descriptions."""
        if self.descriptions:
            return {self.struct_name: self.descriptions}
        return None
+ """ + errors = [] + if not self.struct_name: + errors.append("Structure name cannot be empty") + if not self._fields: + errors.append(f"Structure '{self.struct_name}' has no fields") + for field_name, value in self._fields.items(): + if isinstance(value, ChoiceField): + errors.extend(value.validate(f"{self.struct_name}.{field_name}")) + elif isinstance(value, list): + for i, v in enumerate(value): + if v and v.lower() not in text.lower(): + errors.append(f"List value '{v}' at index {i} in '{self.struct_name}.{field_name}' not found in text") + elif isinstance(value, str): + if value and value.lower() not in text.lower(): + errors.append(f"Value '{value}' for '{self.struct_name}.{field_name}' not found in text") + return errors + + def to_dict(self) -> Dict[str, Dict[str, Any]]: + fields_dict = {} + for field_name, value in self._fields.items(): + if isinstance(value, ChoiceField): + fields_dict[field_name] = value.to_dict() + else: + fields_dict[field_name] = value + return {self.struct_name: fields_dict} + + def get_descriptions(self) -> Optional[Dict[str, Dict[str, str]]]: + if self.descriptions: + return {self.struct_name: self.descriptions} + return None + + +@dataclass +class Relation: + """ + A relation extraction definition. + + Parameters + ---------- + name : str + Name of the relation (e.g., "works_for", "founded"). + head : str, optional + The source/subject entity. + tail : str, optional + The target/object entity. + **fields : Any + Custom field names and values (use instead of head/tail). 
+ """ + name: str + head: Optional[str] = None + tail: Optional[str] = None + _fields: Dict[str, str] = field(default_factory=dict) + + def __init__(self, name: str, head: str = None, tail: str = None, **fields): + self.name = name + self.head = head + self.tail = tail + if fields: + self._fields = fields + elif head is not None and tail is not None: + self._fields = {"head": head, "tail": tail} + else: + self._fields = {} + if head is not None: + self._fields["head"] = head + if tail is not None: + self._fields["tail"] = tail + + def validate(self, text: str) -> List[str]: + """ + Validate this relation. + + Parameters + ---------- + text : str + The text to validate against. Field values must exist in this text. + + Returns + ------- + List[str] + List of validation errors. + """ + errors = [] + if not self.name: + errors.append("Relation name cannot be empty") + if not self._fields: + errors.append(f"Relation '{self.name}' has no fields") + for field_name, value in self._fields.items(): + if isinstance(value, str) and value: + if value.lower() not in text.lower(): + errors.append(f"Relation value '{value}' for '{self.name}.{field_name}' not found in text") + return errors + + def get_field_names(self) -> List[str]: + return list(self._fields.keys()) + + def to_dict(self) -> Dict[str, Dict[str, str]]: + return {self.name: self._fields} + + +@dataclass +class InputExample: + """ + A single training example for GLiNER2. + + Parameters + ---------- + text : str + The input text for this example. + entities : Dict[str, List[str]], optional + Entity type to mentions mapping. + entity_descriptions : Dict[str, str], optional + Descriptions for entity types. + classifications : List[Classification], optional + Classification tasks for this example. + structures : List[Structure], optional + Structured data extractions for this example. + relations : List[Relation], optional + Relation extractions for this example. 

    Examples
    --------
    >>> example = InputExample(
    ...     text="John Smith works at Google.",
    ...     entities={"person": ["John Smith"], "company": ["Google"]}
    ... )
    """
    text: str
    entities: Optional[Dict[str, List[str]]] = None
    entity_descriptions: Optional[Dict[str, str]] = None
    classifications: Optional[List[Classification]] = None
    structures: Optional[List[Structure]] = None
    relations: Optional[List[Relation]] = None

    def __post_init__(self):
        # Normalize all optional task containers to empty collections so the
        # rest of the class can iterate them without None checks.
        if self.entities is None:
            self.entities = {}
        if self.classifications is None:
            self.classifications = []
        if self.structures is None:
            self.structures = []
        if self.relations is None:
            self.relations = []

    def validate(self) -> List[str]:
        """
        Validate this example.

        Validation is always strict: checks that entity mentions, relation values,
        and structure field values exist in the text (case-insensitive).

        Returns
        -------
        List[str]
            List of validation errors. Empty list means valid.
        """
        errors = []
        if not self.text or not self.text.strip():
            # No point validating tasks against empty text; bail out early.
            errors.append("Text cannot be empty")
            return errors

        if self.entities:
            for entity_type, mentions in self.entities.items():
                if not entity_type:
                    errors.append("Entity type cannot be empty")
                for mention in mentions:
                    # Case-insensitive substring containment check.
                    if mention and mention.lower() not in self.text.lower():
                        errors.append(f"Entity '{mention}' (type: {entity_type}) not found in text")

        if self.entity_descriptions and self.entities:
            for desc_type in self.entity_descriptions:
                if desc_type not in self.entities:
                    errors.append(f"Entity description for '{desc_type}' but no entities of that type")

        for cls in self.classifications:
            errors.extend(cls.validate())

        for struct in self.structures:
            errors.extend(struct.validate(self.text))

        # Within one example, every occurrence of the same relation name must
        # use the same set of argument field names.
        relation_fields = {}
        for rel in self.relations:
            errors.extend(rel.validate(self.text))
            field_names = tuple(sorted(rel.get_field_names()))
            if rel.name in relation_fields:
                if relation_fields[rel.name] != field_names:
                    errors.append(f"Relation '{rel.name}' has inconsistent fields: {relation_fields[rel.name]} vs {field_names}")
            else:
                relation_fields[rel.name] = field_names

        has_content = bool(self.entities) or bool(self.classifications) or bool(self.structures) or bool(self.relations)
        if not has_content:
            errors.append("Example must have at least one task (entities, classifications, structures, or relations)")

        return errors

    def is_valid(self) -> bool:
        """Check if this example is valid."""
        return len(self.validate()) == 0

    def sanitize(self) -> Tuple[List[str], bool]:
        """
        Remove invalid parts from this example, keeping only valid content.
        Mutates self in-place.

        Granular removal strategy:
        - Entities: Drop entire entity type if ANY mention is not found in text
        - Classifications: Drop individual classifications that have errors
        - Structures: Remove invalid fields; drop structure only if ALL fields become invalid
        - Relations: Drop the specific relation if ANY field has an error
        - Example: Mark as invalid only if no valid tasks remain

        Returns
        -------
        Tuple[List[str], bool]
            - List of warning messages about what was removed
            - bool: True if example still has valid content, False if should be dropped
        """
        warnings = []

        if not self.text or not self.text.strip():
            warnings.append("Text is empty")
            return warnings, False

        # 1. Sanitize entities - drop entity type if any mention not found
        if self.entities:
            types_to_remove = []
            for entity_type, mentions in self.entities.items():
                if not entity_type:
                    types_to_remove.append(entity_type)
                    warnings.append(f"Entity type is empty")
                    continue

                # Check if any mention is not in text
                has_invalid = False
                for mention in mentions:
                    if mention and mention.lower() not in self.text.lower():
                        has_invalid = True
                        warnings.append(f"Entity '{mention}' (type: {entity_type}) not found in text - dropping entity type")
                        break

                if has_invalid:
                    types_to_remove.append(entity_type)

            # Remove invalid entity types (deferred to avoid mutating the dict
            # while iterating it above)
            for entity_type in types_to_remove:
                del self.entities[entity_type]

            # Clean up entity descriptions for removed types
            if self.entity_descriptions:
                desc_to_remove = [desc_type for desc_type in self.entity_descriptions if desc_type not in self.entities]
                for desc_type in desc_to_remove:
                    del self.entity_descriptions[desc_type]

        # 2. Sanitize classifications - drop individual invalid ones
        if self.classifications:
            valid_classifications = []
            for cls in self.classifications:
                cls_errors = cls.validate()
                if cls_errors:
                    warnings.append(f"Classification '{cls.task}' has errors - dropping: {cls_errors[0]}")
                else:
                    valid_classifications.append(cls)
            self.classifications = valid_classifications

        # 3. Sanitize structures - remove invalid fields, drop if all invalid
        if self.structures:
            valid_structures = []
            for struct in self.structures:
                if not struct.struct_name:
                    warnings.append(f"Structure has empty name - dropping")
                    continue

                if not struct._fields:
                    warnings.append(f"Structure '{struct.struct_name}' has no fields - dropping")
                    continue

                # Filter out invalid fields
                valid_fields = {}
                for field_name, value in struct._fields.items():
                    is_valid = True

                    if isinstance(value, ChoiceField):
                        field_errors = value.validate(f"{struct.struct_name}.{field_name}")
                        if field_errors:
                            warnings.append(f"Field '{struct.struct_name}.{field_name}' invalid - dropping field")
                            is_valid = False
                    elif isinstance(value, list):
                        # One bad list element invalidates the whole field.
                        for v in value:
                            if v and v.lower() not in self.text.lower():
                                warnings.append(f"List value '{v}' in '{struct.struct_name}.{field_name}' not found - dropping field")
                                is_valid = False
                                break
                    elif isinstance(value, str):
                        if value and value.lower() not in self.text.lower():
                            warnings.append(f"Value '{value}' for '{struct.struct_name}.{field_name}' not found - dropping field")
                            is_valid = False

                    if is_valid:
                        valid_fields[field_name] = value

                # Only keep structure if it has at least one valid field
                if valid_fields:
                    struct._fields = valid_fields
                    valid_structures.append(struct)
                else:
                    warnings.append(f"Structure '{struct.struct_name}' has no valid fields - dropping")

            self.structures = valid_structures

        # 4. Sanitize relations - drop entire relation if any field is invalid
        if self.relations:
            valid_relations = []
            for rel in self.relations:
                if not rel.name:
                    warnings.append(f"Relation has empty name - dropping")
                    continue

                if not rel._fields:
                    warnings.append(f"Relation '{rel.name}' has no fields - dropping")
                    continue

                # Check if any field value is invalid
                has_invalid = False
                for field_name, value in rel._fields.items():
                    if isinstance(value, str) and value:
                        if value.lower() not in self.text.lower():
                            warnings.append(f"Relation '{rel.name}' field '{field_name}' value '{value}' not found - dropping relation")
                            has_invalid = True
                            break

                if not has_invalid:
                    valid_relations.append(rel)

            self.relations = valid_relations

        # Check if example still has any valid content
        has_content = bool(self.entities) or bool(self.classifications) or bool(self.structures) or bool(self.relations)

        if not has_content:
            warnings.append("No valid tasks remain after sanitization")
            return warnings, False

        return warnings, True

    def to_dict(self) -> Dict[str, Any]:
        """Convert to GLiNER2 training format."""
        output = {}
        if self.entities:
            output["entities"] = self.entities
        if self.entity_descriptions:
            output["entity_descriptions"] = self.entity_descriptions
        if self.classifications:
            output["classifications"] = [cls.to_dict() for cls in self.classifications]
        if self.structures:
            output["json_structures"] = [struct.to_dict() for struct in self.structures]
            # Merge per-structure descriptions into one "json_descriptions" map.
            all_descriptions = {}
            for struct in self.structures:
                desc = struct.get_descriptions()
                if desc:
                    all_descriptions.update(desc)
            if all_descriptions:
                output["json_descriptions"] = all_descriptions
        if self.relations:
            output["relations"] = [rel.to_dict() for rel in self.relations]
        return {"input": self.text, "output": output}

    def to_json(self) -> str:
        return json.dumps(self.to_dict(), ensure_ascii=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'InputExample':
        """Create InputExample from training format dictionary."""
        text = data["input"]
        output = data["output"]

        entities = output.get("entities")
        entity_descriptions = output.get("entity_descriptions")

        classifications = []
        for cls_data in output.get("classifications", []):
            classifications.append(Classification(
                task=cls_data["task"],
                labels=cls_data["labels"],
                true_label=cls_data["true_label"],
                multi_label=cls_data.get("multi_label", False),
                prompt=cls_data.get("prompt"),
                # "or None" normalizes an empty examples list back to None.
                examples=[tuple(ex) for ex in cls_data.get("examples", [])] or None,
                label_descriptions=cls_data.get("label_descriptions")
            ))

        structures = []
        json_descriptions = output.get("json_descriptions", {})
        for struct_data in output.get("json_structures", []):
            for struct_name, fields in struct_data.items():
                parsed_fields = {}
                for field_name, value in fields.items():
                    # A {"value": ..., "choices": ...} dict round-trips back
                    # into a ChoiceField; anything else is kept as-is.
                    if isinstance(value, dict) and "value" in value and "choices" in value:
                        parsed_fields[field_name] = ChoiceField(value["value"], value["choices"])
                    else:
                        parsed_fields[field_name] = value
                structures.append(Structure(struct_name, _descriptions=json_descriptions.get(struct_name), **parsed_fields))

        relations = []
        for rel_data in output.get("relations", []):
            for rel_name, fields in rel_data.items():
                # Exactly {head, tail} maps onto the head/tail convenience
                # parameters; any other field set is passed through verbatim.
                if "head" in fields and "tail" in fields and len(fields) == 2:
                    relations.append(Relation(rel_name, head=fields["head"], tail=fields["tail"]))
                else:
                    relations.append(Relation(rel_name, **fields))

        return cls(
            text=text,
            entities=entities,
            entity_descriptions=entity_descriptions,
            classifications=classifications if classifications else None,
            structures=structures if structures else None,
            relations=relations if relations else None
        )

    @classmethod
    def from_json(cls, json_str: str) -> 'InputExample':
        return cls.from_dict(json.loads(json_str))


class TrainingDataset:
    """
    A collection of InputExamples for training GLiNER2.
+ + Can be created from: + - List of InputExample objects + - JSONL file path(s) + - Raw dict data + + Parameters + ---------- + examples : List[InputExample], optional + Initial list of examples. + + Examples + -------- + >>> # From InputExample list + >>> dataset = TrainingDataset([example1, example2]) + >>> + >>> # From JSONL file + >>> dataset = TrainingDataset.load("train.jsonl") + >>> + >>> # From multiple JSONL files + >>> dataset = TrainingDataset.load(["train1.jsonl", "train2.jsonl"]) + """ + + def __init__(self, examples: List[InputExample] = None): + self.examples: List[InputExample] = examples or [] + + def __len__(self) -> int: + return len(self.examples) + + def __getitem__(self, idx: int) -> InputExample: + return self.examples[idx] + + def __iter__(self) -> Iterator[InputExample]: + return iter(self.examples) + + def add(self, example: InputExample) -> 'TrainingDataset': + self.examples.append(example) + return self + + def add_many(self, examples: List[InputExample]) -> 'TrainingDataset': + self.examples.extend(examples) + return self + + def validate(self, raise_on_error: bool = True) -> Dict[str, Any]: + """ + Validate all examples in the dataset. + + Validation is always strict: checks that entity mentions, relation values, + and structure field values exist in the text (case-insensitive). + + Parameters + ---------- + raise_on_error : bool, default=True + If True, raises ValidationError when invalid examples are found. + If False, returns validation report without raising. + + Returns + ------- + Dict[str, Any] + Validation report with counts and error details. 
+ """ + all_errors = [] + valid_count = 0 + invalid_indices = [] + + for i, example in enumerate(self.examples): + errors = example.validate() + if errors: + invalid_indices.append(i) + for error in errors: + all_errors.append(f"Example {i}: {error}") + else: + valid_count += 1 + + report = { + "valid": valid_count, + "invalid": len(invalid_indices), + "total": len(self.examples), + "invalid_indices": invalid_indices, + "errors": all_errors + } + + if all_errors and raise_on_error: + raise ValidationError(f"Dataset validation failed: {len(invalid_indices)} invalid examples", all_errors) + + return report + + def validate_relation_consistency(self) -> List[str]: + """Validate that relation field structures are consistent across the dataset.""" + errors = [] + relation_fields: Dict[str, Tuple[int, Tuple[str, ...]]] = {} + + for i, example in enumerate(self.examples): + for rel in example.relations: + field_names = tuple(sorted(rel.get_field_names())) + if rel.name in relation_fields: + first_idx, first_fields = relation_fields[rel.name] + if first_fields != field_names: + errors.append(f"Relation '{rel.name}' field inconsistency: Example {first_idx} has {list(first_fields)}, but Example {i} has {list(field_names)}") + else: + relation_fields[rel.name] = (i, field_names) + return errors + + def stats(self) -> Dict[str, Any]: + """Get statistics about the dataset.""" + stats = { + "total_examples": len(self.examples), + "entity_types": Counter(), + "entity_mentions": 0, + "classification_tasks": Counter(), + "classification_labels": {}, + "structure_types": Counter(), + "relation_types": Counter(), + "text_lengths": [], + "task_distribution": { + "entities_only": 0, "classifications_only": 0, "structures_only": 0, + "relations_only": 0, "multi_task": 0, "empty": 0 + } + } + + for example in self.examples: + stats["text_lengths"].append(len(example.text)) + for entity_type, mentions in example.entities.items(): + stats["entity_types"][entity_type] += len(mentions) + 
stats["entity_mentions"] += len(mentions) + for cls in example.classifications: + stats["classification_tasks"][cls.task] += 1 + if cls.task not in stats["classification_labels"]: + stats["classification_labels"][cls.task] = Counter() + for label in cls._true_label_list: + stats["classification_labels"][cls.task][label] += 1 + for struct in example.structures: + stats["structure_types"][struct.struct_name] += 1 + for rel in example.relations: + stats["relation_types"][rel.name] += 1 + + has_entities = bool(example.entities) + has_cls = bool(example.classifications) + has_struct = bool(example.structures) + has_rel = bool(example.relations) + task_count = sum([has_entities, has_cls, has_struct, has_rel]) + + if task_count == 0: + stats["task_distribution"]["empty"] += 1 + elif task_count > 1: + stats["task_distribution"]["multi_task"] += 1 + elif has_entities: + stats["task_distribution"]["entities_only"] += 1 + elif has_cls: + stats["task_distribution"]["classifications_only"] += 1 + elif has_struct: + stats["task_distribution"]["structures_only"] += 1 + elif has_rel: + stats["task_distribution"]["relations_only"] += 1 + + if stats["text_lengths"]: + lengths = stats["text_lengths"] + stats["text_length_stats"] = { + "min": min(lengths), "max": max(lengths), + "mean": sum(lengths) / len(lengths), + "median": sorted(lengths)[len(lengths) // 2] + } + + stats["entity_types"] = dict(stats["entity_types"]) + stats["classification_tasks"] = dict(stats["classification_tasks"]) + stats["classification_labels"] = {k: dict(v) for k, v in stats["classification_labels"].items()} + stats["structure_types"] = dict(stats["structure_types"]) + stats["relation_types"] = dict(stats["relation_types"]) + + return stats + + def print_stats(self): + """Print formatted statistics.""" + s = self.stats() + print(f"\n{'='*60}") + print(f"GLiNER2 Training Dataset Statistics") + print(f"{'='*60}") + print(f"Total examples: {s['total_examples']}") + + if s.get('text_length_stats'): + tls = 
s['text_length_stats'] + print(f"\nText lengths: min={tls['min']}, max={tls['max']}, mean={tls['mean']:.1f}") + + print(f"\nTask Distribution:") + for task, count in s['task_distribution'].items(): + if count > 0: + print(f" {task}: {count} ({100*count/s['total_examples']:.1f}%)") + + if s['entity_types']: + print(f"\nEntity Types ({s['entity_mentions']} total mentions):") + for etype, count in sorted(s['entity_types'].items(), key=lambda x: -x[1]): + print(f" {etype}: {count}") + + if s['classification_tasks']: + print(f"\nClassification Tasks:") + for task, count in s['classification_tasks'].items(): + print(f" {task}: {count} examples") + if task in s['classification_labels']: + for label, lcount in s['classification_labels'][task].items(): + print(f" - {label}: {lcount}") + + if s['structure_types']: + print(f"\nStructure Types:") + for stype, count in s['structure_types'].items(): + print(f" {stype}: {count}") + + if s['relation_types']: + print(f"\nRelation Types:") + for rtype, count in s['relation_types'].items(): + print(f" {rtype}: {count}") + + print(f"{'='*60}\n") + + def to_jsonl(self) -> str: + return "\n".join(example.to_json() for example in self.examples) + + def to_records(self) -> List[Dict[str, Any]]: + """Convert to list of record dicts for trainer.""" + return [ex.to_dict() for ex in self.examples] + + def save(self, path: Union[str, Path], validate_first: bool = True): + """Save dataset to JSONL file.""" + if validate_first: + self.validate() + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, 'w', encoding='utf-8') as f: + for example in self.examples: + f.write(example.to_json() + '\n') + print(f"Saved {len(self.examples)} examples to {path}") + + @classmethod + def load(cls, paths: Union[str, Path, List[Union[str, Path]]], shuffle: bool = False, seed: int = 42) -> 'TrainingDataset': + """ + Load dataset from JSONL file(s). 
+ + Parameters + ---------- + paths : str, Path, or List + Single file path or list of file paths. + shuffle : bool, default=False + Whether to shuffle the loaded examples. + seed : int, default=42 + Random seed for shuffling. + + Returns + ------- + TrainingDataset + """ + if isinstance(paths, (str, Path)): + paths = [paths] + + examples = [] + for path in paths: + path = Path(path) + with open(path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if line: + try: + data = json.loads(line) + examples.append(InputExample.from_dict(data)) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in {path} line {line_num}: {e}") + except Exception as e: + raise ValueError(f"Error parsing {path} line {line_num}: {e}") + print(f"Loaded {len(examples)} examples from {path}") + + if shuffle: + random.seed(seed) + random.shuffle(examples) + + return cls(examples) + + @classmethod + def from_records(cls, records: List[Dict[str, Any]]) -> 'TrainingDataset': + """Create dataset from list of record dicts.""" + examples = [InputExample.from_dict(r) for r in records] + return cls(examples) + + def split(self, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, + shuffle: bool = True, seed: int = 42) -> Tuple['TrainingDataset', 'TrainingDataset', 'TrainingDataset']: + """Split dataset into train/val/test sets.""" + if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: + raise ValueError("Ratios must sum to 1.0") + + indices = list(range(len(self.examples))) + if shuffle: + random.seed(seed) + random.shuffle(indices) + + n = len(indices) + train_end = int(n * train_ratio) + val_end = train_end + int(n * val_ratio) + + return ( + TrainingDataset([self.examples[i] for i in indices[:train_end]]), + TrainingDataset([self.examples[i] for i in indices[train_end:val_end]]), + TrainingDataset([self.examples[i] for i in indices[val_end:]]) + ) + + def filter(self, predicate) -> 'TrainingDataset': + 
"""Filter examples based on a predicate function.""" + return TrainingDataset([ex for ex in self.examples if predicate(ex)]) + + def sample(self, n: int, seed: int = 42) -> 'TrainingDataset': + """Random sample of examples.""" + random.seed(seed) + return TrainingDataset(random.sample(self.examples, min(n, len(self.examples)))) + + +# Convenience functions +def create_entity_example(text: str, entities: Dict[str, List[str]], descriptions: Dict[str, str] = None) -> InputExample: + """Create an entity extraction example.""" + return InputExample(text=text, entities=entities, entity_descriptions=descriptions) + + +def create_classification_example(text: str, task: str, labels: List[str], true_label: Union[str, List[str]], + multi_label: bool = False, **kwargs) -> InputExample: + """Create a classification example.""" + return InputExample(text=text, classifications=[Classification(task=task, labels=labels, true_label=true_label, multi_label=multi_label, **kwargs)]) + + +def create_structure_example(text: str, structure_name: str, **fields) -> InputExample: + """Create a structured data example.""" + return InputExample(text=text, structures=[Structure(structure_name, **fields)]) + + +def create_relation_example(text: str, relation_name: str, head: str = None, tail: str = None, **fields) -> InputExample: + """Create a relation extraction example.""" + return InputExample(text=text, relations=[Relation(relation_name, head=head, tail=tail, **fields)]) \ No newline at end of file diff --git a/packages/GLiNER2/gliner2/training/lora.py b/packages/GLiNER2/gliner2/training/lora.py new file mode 100644 index 0000000..8777ead --- /dev/null +++ b/packages/GLiNER2/gliner2/training/lora.py @@ -0,0 +1,836 @@ +""" +Custom LoRA (Low-Rank Adaptation) Implementation for GLiNER2 +============================================================= + +Parameter-efficient fine-tuning by injecting trainable low-rank matrices +into frozen linear layers of the encoder. 

Based on: "LoRA: Low-Rank Adaptation of Large Language Models"
Paper: https://arxiv.org/abs/2106.09685
"""

from __future__ import annotations

import json
import logging
import math
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
from safetensors.torch import save_file, load_file

logger = logging.getLogger(__name__)


# =============================================================================
# LoRA Configuration
# =============================================================================

@dataclass
class LoRAConfig:
    """
    Configuration for LoRA parameter-efficient fine-tuning.

    Parameters
    ----------
    enabled : bool
        Whether LoRA is enabled.
    r : int
        Rank of the low-rank decomposition (bottleneck dimension).
        Higher r = more parameters but better approximation.
        Typical values: 4, 8, 16, 32, 64.
    alpha : float
        Scaling factor for LoRA updates. Final scaling is alpha/r.
        Typical values: 8, 16, 32 (often 2*r).
    dropout : float
        Dropout probability applied to LoRA path.
    target_modules : List[str]
        Module names to apply LoRA to. Supported module groups:

        - "encoder" - Applies LoRA to query, key, value, dense layers within encoder
        - "encoder.query" - Only query layers in encoder
        - "encoder.key" - Only key layers in encoder
        - "encoder.value" - Only value layers in encoder
        - "encoder.dense" - Only dense layers in encoder
        - "span_rep" - Applies LoRA to ALL linear layers within span_rep
        - "classifier" - Applies LoRA to ALL linear layers within classifier
        - "count_embed" - Applies LoRA to ALL linear layers within count_embed
        - "count_pred" - Applies LoRA to ALL linear layers within count_pred

        Examples:
        - ["encoder"] - all encoder layers (query, key, value, dense)
        - ["encoder.query", "encoder.key", "encoder.value"] - only attention layers
        - ["encoder.dense"] - only dense (FFN) layers in encoder
        - ["encoder", "span_rep", "classifier"] - encoder + task heads
        - ["classifier"] - classifier fine-tuning only
    """
    enabled: bool = False
    r: int = 8
    alpha: float = 16.0
    dropout: float = 0.0
    target_modules: List[str] = field(default_factory=lambda: ["encoder"])

    def __post_init__(self):
        # Fail fast on nonsensical hyperparameters; r/alpha/dropout are
        # checked even when LoRA is disabled.
        if self.r <= 0:
            raise ValueError(f"LoRA rank must be > 0, got {self.r}")
        if self.alpha <= 0:
            raise ValueError(f"LoRA alpha must be > 0, got {self.alpha}")
        if not 0 <= self.dropout < 1:
            raise ValueError(f"LoRA dropout must be in [0, 1), got {self.dropout}")
        if self.enabled and not self.target_modules:
            raise ValueError("target_modules cannot be empty when LoRA is enabled")


@dataclass
class LoRAAdapterConfig:
    """
    Configuration for a saved LoRA adapter.

    This is the config that gets saved with adapter-only checkpoints.
+ """ + adapter_type: str = "lora" + adapter_version: str = "1.0" + lora_r: int = 8 + lora_alpha: float = 16.0 + lora_dropout: float = 0.0 + target_modules: List[str] = field(default_factory=list) + created_at: str = "" + + def save(self, path: Union[str, Path]) -> None: + """Save adapter config to JSON file.""" + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + config_path = path / "adapter_config.json" + + # Set created_at if not set + if not self.created_at: + self.created_at = datetime.utcnow().isoformat() + "Z" + + with open(config_path, "w") as f: + json.dump(asdict(self), f, indent=2) + + logger.info(f"Saved adapter config to {config_path}") + + @classmethod + def load(cls, path: Union[str, Path]) -> 'LoRAAdapterConfig': + """Load adapter config from JSON file or directory.""" + path = Path(path) + + # If path is a directory, look for adapter_config.json + if path.is_dir(): + config_path = path / "adapter_config.json" + else: + config_path = path + + if not config_path.exists(): + raise FileNotFoundError(f"Adapter config not found at {config_path}") + + with open(config_path) as f: + config_dict = json.load(f) + + return cls(**config_dict) + + @classmethod + def is_adapter_path(cls, path: Union[str, Path]) -> bool: + """Check if path contains an adapter.""" + path = Path(path) + + # Check for adapter_config.json + if path.is_dir(): + return (path / "adapter_config.json").exists() + else: + return path.name == "adapter_config.json" and path.exists() + + +# ============================================================================= +# LoRA Layer +# ============================================================================= + +class LoRALayer(nn.Module): + """ + LoRA-enhanced Linear layer. + + Computes: output = W*x + (B*A*x) * scaling + Where: + - W is the frozen original weight + - A, B are trainable low-rank matrices + - scaling = alpha / r + + Parameters + ---------- + base_layer : nn.Linear + Original linear layer (will be frozen). 
+ r : int + Rank of low-rank decomposition. + alpha : float + LoRA scaling factor. + dropout : float + Dropout probability. + """ + + def __init__( + self, + base_layer: nn.Linear, + r: int, + alpha: float, + dropout: float = 0.0, + ): + super().__init__() + + self.r = r + self.alpha = alpha + self.scaling = alpha / r + + in_features = base_layer.in_features + out_features = base_layer.out_features + + # Store frozen base layer + self.base_layer = base_layer + for param in self.base_layer.parameters(): + param.requires_grad = False + + # Get device from base layer to ensure LoRA parameters are on same device + device = next(base_layer.parameters()).device + + # LoRA low-rank matrices + # A: (r, in_features) - initialized with small random values + # B: (out_features, r) - initialized to zero (no change at start) + self.lora_A = nn.Parameter(torch.zeros(r, in_features, device=device)) + self.lora_B = nn.Parameter(torch.zeros(out_features, r, device=device)) + + # Initialize A with Kaiming uniform (same as nn.Linear default) + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + # B stays zero-initialized + + # Dropout + self.lora_dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity() + + # Flag to track if weights are merged + self.merged = False + + # Expose base layer attributes for compatibility + @property + def weight(self): + """Expose weight from base layer for compatibility.""" + return self.base_layer.weight + + @property + def bias(self): + """Expose bias from base layer for compatibility.""" + return self.base_layer.bias + + @property + def in_features(self): + """Expose in_features from base layer.""" + return self.base_layer.in_features + + @property + def out_features(self): + """Expose out_features from base layer.""" + return self.base_layer.out_features + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with LoRA. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (..., in_features). 
# Module-specific patterns for LoRA application
ENCODER_PATTERNS = ["query", "key", "value", "dense"]
ALL_LINEAR_MODULES = ["span_rep", "classifier", "count_embed", "count_pred"]

def apply_lora_to_model(
    model: nn.Module,
    config: LoRAConfig,
) -> Tuple[nn.Module, Dict[str, LoRALayer]]:
    """
    Inject LoRA layers into linear sub-modules selected by ``target_modules``.

    Module group behavior:
    - "encoder": query/key/value/dense linears anywhere under ``encoder.``
    - "encoder.query" / "encoder.key" / "encoder.value" / "encoder.dense":
      only that layer kind within the encoder
    - "span_rep", "classifier", "count_embed", "count_pred": every linear
      layer within that module

    Parameters
    ----------
    model : nn.Module
        Model to modify in place.
    config : LoRAConfig
        LoRA configuration (rank, alpha, dropout, target modules, enabled).

    Returns
    -------
    model : nn.Module
        The same model with matching linears wrapped in LoRALayer.
    lora_layers : Dict[str, LoRALayer]
        Full module path -> injected LoRA layer.
    """
    if not config.enabled:
        logger.info("LoRA is disabled, skipping application")
        return model, {}

    replaced: Dict[str, LoRALayer] = {}

    def _matches(local_name: str, full_path: str) -> bool:
        """True when any configured target group selects this layer."""
        for target in config.target_modules:
            if target == "encoder":
                # Encoder group: only attention/FFN layer names count.
                if full_path.startswith("encoder.") and any(
                    pat in local_name for pat in ENCODER_PATTERNS
                ):
                    return True
            elif target.startswith("encoder."):
                # Narrowed encoder group, e.g. "encoder.query" -> "query".
                wanted = target.split(".", 1)[1]
                if full_path.startswith("encoder.") and wanted in local_name:
                    return True
            elif target in ALL_LINEAR_MODULES and full_path.startswith(f"{target}."):
                # Head modules: every linear underneath qualifies.
                return True
        return False

    def _walk(parent: nn.Module, prefix: str = ""):
        for child_name, child in parent.named_children():
            path = f"{prefix}.{child_name}" if prefix else child_name
            if isinstance(child, nn.Linear) and _matches(child_name, path):
                wrapped = LoRALayer(
                    base_layer=child,
                    r=config.r,
                    alpha=config.alpha,
                    dropout=config.dropout,
                )
                setattr(parent, child_name, wrapped)
                replaced[path] = wrapped
                logger.debug(
                    f"Applied LoRA to {path} "
                    f"(in={child.in_features}, out={child.out_features})"
                )
            else:
                _walk(child, path)

    _walk(model)

    if not replaced:
        logger.warning(
            f"No LoRA layers were applied. Target modules {config.target_modules} "
            f"not found. Check your target_modules configuration."
        )
    else:
        logger.info(f"Applied LoRA to {len(replaced)} layers")

    return model, replaced
def get_lora_parameters(model: nn.Module) -> List[nn.Parameter]:
    """
    Collect every trainable LoRA factor (lora_A and lora_B) in *model*.

    Parameters
    ----------
    model : nn.Module
        Model containing LoRA layers.

    Returns
    -------
    List[nn.Parameter]
        All LoRA parameters, A before B for each layer.
    """
    params: List[nn.Parameter] = []
    for m in model.modules():
        if isinstance(m, LoRALayer):
            params.append(m.lora_A)
            params.append(m.lora_B)
    return params


def get_lora_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
    """
    Build a state dict restricted to LoRA factors, keyed by full module path.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``{"<path>.lora_A": ..., "<path>.lora_B": ...}`` per LoRA layer.
    """
    return {
        f"{name}.{suffix}": getattr(m, suffix).data
        for name, m in model.named_modules()
        if isinstance(m, LoRALayer)
        for suffix in ("lora_A", "lora_B")
    }


def merge_lora_weights(model: nn.Module) -> int:
    """
    Fold every LoRA update into its base layer, then strip the wrappers.

    After this call the model contains standard Linear layers whose weights
    already include the adapter contribution; the LoRA structure is removed.

    Returns
    -------
    int
        Number of layers whose weights were newly merged.
    """
    merged_now = 0
    skipped = 0
    for m in model.modules():
        if not isinstance(m, LoRALayer):
            continue
        if m.merged:
            skipped += 1
        else:
            m.merge_weights()
            merged_now += 1

    if merged_now > 0:
        logger.debug(f"Merged LoRA weights in {merged_now} layers")
    if skipped > 0:
        logger.debug(f"Skipped {skipped} layers (already merged)")

    # Drop the LoRA wrappers entirely once everything is folded in.
    if merged_now > 0 or skipped > 0:
        remove_lora_from_model(model)
        logger.info(f"Merged and removed LoRA layers from model")

    return merged_now


def unmerge_lora_weights(model: nn.Module) -> int:
    """
    Subtract previously-merged LoRA contributions from their base layers.

    Returns
    -------
    int
        Number of layers unmerged.
    """
    undone = 0
    untouched = 0
    for m in model.modules():
        if not isinstance(m, LoRALayer):
            continue
        if m.merged:
            m.unmerge_weights()
            undone += 1
        else:
            untouched += 1

    if undone > 0:
        logger.debug(f"Unmerged LoRA weights in {undone} layers")
    if untouched > 0:
        logger.debug(f"Skipped {untouched} layers (not merged)")
    return undone
def count_lora_parameters(model: nn.Module) -> Tuple[int, int, float]:
    """
    Tally trainable LoRA parameters against the full model size.

    Returns
    -------
    lora_params : int
        Number of trainable LoRA parameters.
    total_params : int
        Total number of model parameters.
    percentage : float
        LoRA share of all parameters, in percent (0.0 for an empty model).
    """
    trainable = sum(p.numel() for p in get_lora_parameters(model))
    overall = sum(p.numel() for p in model.parameters())
    share = (trainable / overall * 100) if overall > 0 else 0.0
    return trainable, overall, share


def print_lora_info(model: nn.Module, config: LoRAConfig):
    """
    Pretty-print the active LoRA configuration and parameter budget.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA layers.
    config : LoRAConfig
        LoRA configuration to display.
    """
    trainable, total, pct = count_lora_parameters(model)
    layer_count = sum(1 for m in model.modules() if isinstance(m, LoRALayer))

    bar = "=" * 70
    print(bar)
    print("🔧 LoRA Configuration")
    print(bar)
    print(f"Enabled : {config.enabled}")
    print(f"Rank (r) : {config.r}")
    print(f"Alpha : {config.alpha}")
    print(f"Scaling (α/r) : {config.alpha / config.r:.4f}")
    print(f"Dropout : {config.dropout}")
    print(f"Target modules : {', '.join(config.target_modules)}")
    print(f"LoRA layers : {layer_count}")
    print("-" * 70)
    print(f"Trainable params : {trainable:,} / {total:,} ({pct:.2f}%)")
    print(f"Memory savings : ~{100 - pct:.1f}% fewer gradients")
    print(bar)


def remove_lora_from_model(model: nn.Module) -> nn.Module:
    """
    Replace every LoRALayer with its underlying Linear layer.

    Weights are merged first when necessary, so the restored Linear layers
    keep the adapter's contribution. Useful for inference with merged
    weights.

    Returns
    -------
    nn.Module
        The same model with LoRA wrappers removed.
    """
    def _strip(parent: nn.Module):
        for child_name, child in parent.named_children():
            if isinstance(child, LoRALayer):
                # Make sure the adapter contribution survives the removal.
                if not child.merged:
                    child.merge_weights()
                setattr(parent, child_name, child.base_layer)
                logger.debug(f"Removed LoRA from {child_name}, restored base layer")
            else:
                _strip(child)

    _strip(model)
    logger.info("Removed all LoRA layers from model")
    return model
+ ) + # Save LoRA matrices with full path from model root + lora_state[f"{name}.lora_A"] = module.lora_A.data + lora_state[f"{name}.lora_B"] = module.lora_B.data + + # Extract config from first LoRA layer + if lora_config is None: + lora_config = { + "lora_r": module.r, + "lora_alpha": module.alpha, + "lora_dropout": module.lora_dropout.p if hasattr(module.lora_dropout, 'p') else 0.0, + } + + if not lora_state: + raise ValueError("No LoRA layers found in model") + + # Save weights + weights_path = save_path / "adapter_weights.safetensors" + save_file(lora_state, str(weights_path)) + logger.info(f"Saved {len(lora_state)} LoRA tensors to {weights_path}") + + # Determine target modules from layer names + # Extract top-level module names (encoder, span_rep, classifier, etc.) + target_modules = set() + for key in lora_state.keys(): + # Extract first level module from full path + # e.g., "encoder.layer.0.attention.self.query.lora_A" -> "encoder" + # e.g., "span_rep.project_start.0.lora_A" -> "span_rep" + parts = key.split(".") + if len(parts) > 0: + # Get the first level module name + module_name = parts[0] + target_modules.add(module_name) + + # Create and save adapter config + adapter_config = LoRAAdapterConfig( + adapter_type="lora", + adapter_version="1.0", + lora_r=lora_config["lora_r"], + lora_alpha=lora_config["lora_alpha"], + lora_dropout=lora_config["lora_dropout"], + target_modules=sorted(list(target_modules)), + created_at=datetime.utcnow().isoformat() + "Z" + ) + adapter_config.save(save_path) + + logger.info(f"Saved LoRA adapter to {save_path}") + + +def load_lora_adapter( + model: nn.Module, + adapter_path: Union[str, Path], + auto_unload: bool = True, +) -> Dict[str, LoRALayer]: + """ + Load LoRA adapter onto model. 
+ + Args: + model: Base model (should not have LoRA applied) + adapter_path: Path to adapter directory + auto_unload: If True, unload existing adapter first + + Returns: + Dict of LoRA layers that were applied + """ + adapter_path = Path(adapter_path) + + # Load adapter config + adapter_config = LoRAAdapterConfig.load(adapter_path) + + # Unload existing adapter if requested + if auto_unload and has_lora_adapter(model): + logger.info("Unloading existing adapter before loading new one") + unload_lora_adapter(model) + + # Load adapter weights + weights_path = adapter_path / "adapter_weights.safetensors" + if not weights_path.exists(): + raise FileNotFoundError(f"Adapter weights not found at {weights_path}") + + lora_state = load_file(str(weights_path)) + logger.info(f"Loaded {len(lora_state)} LoRA tensors from {weights_path}") + + # Apply LoRA to matching layers + lora_config = LoRAConfig( + enabled=True, + r=adapter_config.lora_r, + alpha=adapter_config.lora_alpha, + dropout=adapter_config.lora_dropout, + target_modules=adapter_config.target_modules, + ) + + model, lora_layers = apply_lora_to_model(model, lora_config) + + # Load saved weights into LoRA layers + for name, module in model.named_modules(): + if isinstance(module, LoRALayer): + lora_a_key = f"{name}.lora_A" + lora_b_key = f"{name}.lora_B" + + if lora_a_key in lora_state and lora_b_key in lora_state: + # Move loaded tensors to the same device as the module + device = next(module.parameters()).device + module.lora_A.data = lora_state[lora_a_key].to(device) + module.lora_B.data = lora_state[lora_b_key].to(device) + logger.debug(f"Loaded weights for {name}") + else: + logger.warning(f"No saved weights found for {name}") + + logger.info(f"Loaded LoRA adapter from {adapter_path}") + return lora_layers + + +def unload_lora_adapter(model: nn.Module) -> int: + """ + Remove all LoRA layers, restoring original Linear layers. + + Unlike remove_lora_from_model, this does NOT merge weights. 
def unload_lora_adapter(model: nn.Module) -> int:
    """
    Detach every LoRA layer, restoring the original Linear layers.

    Unlike remove_lora_from_model, weights are NOT merged first — the
    adapter's contribution is simply discarded.

    Returns
    -------
    int
        Number of layers unloaded.
    """
    # Snapshot first: we mutate the module tree while replacing children.
    attached = [
        (path, mod)
        for path, mod in model.named_modules()
        if isinstance(mod, LoRALayer)
    ]

    def _resolve_parent(root: nn.Module, dotted: str) -> Tuple[nn.Module, str]:
        """Walk the dotted path, returning (parent module, leaf attribute)."""
        *ancestors, leaf = dotted.split('.')
        node = root
        for step in ancestors:
            node = getattr(node, step)
        return node, leaf

    removed = 0
    for path, lora_mod in attached:
        parent, leaf = _resolve_parent(model, path)
        # Restore the frozen base layer without folding in the update.
        setattr(parent, leaf, lora_mod.base_layer)
        removed += 1
        logger.debug(f"Unloaded LoRA from {path}")

    if removed > 0:
        logger.info(f"Unloaded {removed} LoRA layers")

    return removed


def has_lora_adapter(model: nn.Module) -> bool:
    """Return True when any LoRALayer is attached to *model*."""
    return any(isinstance(m, LoRALayer) for m in model.modules())


def get_adapter_config(model: nn.Module) -> Optional[LoRAAdapterConfig]:
    """
    Derive an adapter config from the LoRA layers currently attached.

    Returns None when no adapter is loaded. Note: this reconstructs the
    config from the live layers; the config recorded at load time is kept
    in model._adapter_config when loaded via model.load_adapter().
    """
    if not has_lora_adapter(model):
        return None

    sample = None
    groups = set()
    for path, mod in model.named_modules():
        if isinstance(mod, LoRALayer):
            if sample is None:
                # Hyper-parameters are read off the first LoRA layer.
                sample = mod
            groups.add(path.split(".", 1)[0])

    return LoRAAdapterConfig(
        adapter_type="lora",
        adapter_version="1.0",
        lora_r=sample.r,
        lora_alpha=sample.alpha,
        lora_dropout=sample.lora_dropout.p if hasattr(sample.lora_dropout, 'p') else 0.0,
        target_modules=sorted(list(groups)),
        created_at=""
    )
From JSONL file(s) + >>> trainer.train(train_data="train.jsonl") + >>> trainer.train(train_data=["train1.jsonl", "train2.jsonl"]) +>>> +>>> # 3. From TrainingDataset +>>> dataset = TrainingDataset.load("train.jsonl") +>>> trainer.train(train_data=dataset) +""" + +from __future__ import annotations + +import gc +import json +import logging +import math +import os +import random +import shutil +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, asdict +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import GradScaler, autocast +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm.auto import tqdm + +from gliner2.processor import SchemaTransformer, SamplingConfig + +# Import training data classes +from gliner2.training.data import ( + InputExample, TrainingDataset, ValidationError, + DataFormat, detect_data_format, DataLoader_Factory, TrainDataInput +) + +# Import LoRA for parameter-efficient fine-tuning +from gliner2.training.lora import ( + LoRAConfig, apply_lora_to_model, get_lora_parameters, + merge_lora_weights, count_lora_parameters, print_lora_info +) + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Configuration +# ============================================================================= + +@dataclass +class TrainingConfig: + """ + Complete training configuration. + + Parameters + ---------- + output_dir : str + Directory for saving checkpoints and logs. + experiment_name : str + Name of the experiment (used for logging). + num_epochs : int + Number of training epochs. + max_steps : int + Maximum training steps (-1 = determined by epochs). 
@dataclass
class TrainingConfig:
    """
    Complete training configuration for GLiNER2Trainer.

    Key parameters
    --------------
    output_dir / experiment_name : checkpoint + logging locations.
    num_epochs / max_steps : max_steps > 0 overrides epoch-based length.
    batch_size / eval_batch_size / gradient_accumulation_steps :
        effective batch size is batch_size * gradient_accumulation_steps.
    encoder_lr / task_lr : separate LRs for encoder vs task heads.
    scheduler_type : "linear", "cosine", "cosine_restarts" or "constant";
        warmup_steps > 0 overrides warmup_ratio.
    fp16 / bf16 : mutually exclusive mixed-precision modes; bf16 silently
        falls back to fp16 when the hardware does not support it.
    eval_strategy / eval_steps : evaluation cadence ("epoch", "steps", "no").
    save_total_limit / save_best / metric_for_best / greater_is_better :
        checkpoint retention and best-model selection.
    report_to_wandb + wandb_* : optional Weights & Biases logging.
    early_stopping / early_stopping_patience / early_stopping_threshold :
        optional early stop on the tracked metric.
    use_lora / lora_r / lora_alpha / lora_dropout / lora_target_modules :
        parameter-efficient fine-tuning. Target groups: "encoder",
        "encoder.query", "encoder.key", "encoder.value", "encoder.dense",
        "span_rep", "classifier", "count_embed", "count_pred".
    save_adapter_only : when use_lora, save just adapter weights.
    validate_data : validate training data before training.
    """
    output_dir: str = "./output"
    experiment_name: str = "gliner2"
    num_epochs: int = 10
    max_steps: int = -1
    batch_size: int = 2
    eval_batch_size: int = 8
    gradient_accumulation_steps: int = 1
    encoder_lr: float = 1e-5
    task_lr: float = 5e-4
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    scheduler_type: str = "linear"
    warmup_ratio: float = 0.1
    warmup_steps: int = 0
    num_cycles: float = 0.5
    fp16: bool = True
    bf16: bool = False
    eval_strategy: str = "steps"
    eval_steps: int = 500
    save_total_limit: int = 3
    save_best: bool = True
    metric_for_best: str = "eval_loss"
    greater_is_better: bool = False
    logging_steps: int = 1
    logging_first_step: bool = True
    report_to_wandb: bool = False
    wandb_project: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_run_name: Optional[str] = None
    wandb_tags: List[str] = field(default_factory=list)
    wandb_notes: Optional[str] = None
    early_stopping: bool = False
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 0.0
    num_workers: int = 4
    pin_memory: bool = True
    prefetch_factor: int = 2
    seed: int = 42
    deterministic: bool = False
    local_rank: int = -1
    debug: bool = False
    max_train_samples: int = -1
    max_eval_samples: int = -1
    validate_data: bool = True

    # LoRA Configuration (Parameter-Efficient Fine-Tuning)
    use_lora: bool = False
    lora_r: int = 16
    lora_alpha: float = 32.0
    lora_dropout: float = 0.0
    lora_target_modules: List[str] = field(default_factory=lambda: ["encoder", "span_rep", "classifier", "count_embed", "count_pred"])
    save_adapter_only: bool = True  # Only applies when use_lora=True

    def __post_init__(self):
        """Validate field combinations; raises ValueError on invalid config."""
        if self.fp16 and self.bf16:
            raise ValueError("Cannot use both fp16 and bf16")
        # FIX: check cuda.is_available() first — torch.cuda.is_bf16_supported()
        # can raise on CUDA-less builds (older torch), and bf16 autocast
        # requires a CUDA device for this trainer anyway.
        if self.bf16 and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
            logger.warning("bf16 not supported, falling back to fp16")
            self.bf16 = False
            self.fp16 = True

        # Validate logging_steps
        if self.logging_steps <= 0:
            raise ValueError(f"logging_steps must be > 0, got {self.logging_steps}")

        # Validate batch sizes
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be > 0, got {self.batch_size}")

        if self.eval_batch_size <= 0:
            raise ValueError(f"eval_batch_size must be > 0, got {self.eval_batch_size}")

        # Validate gradient_accumulation_steps
        if self.gradient_accumulation_steps <= 0:
            raise ValueError(f"gradient_accumulation_steps must be > 0, got {self.gradient_accumulation_steps}")

        # Validate LoRA configuration
        if self.use_lora:
            if self.lora_r <= 0:
                raise ValueError(f"lora_r must be > 0, got {self.lora_r}")
            if self.lora_alpha <= 0:
                raise ValueError(f"lora_alpha must be > 0, got {self.lora_alpha}")
            if not 0 <= self.lora_dropout < 1:
                raise ValueError(f"lora_dropout must be in [0, 1), got {self.lora_dropout}")
            if not self.lora_target_modules:
                raise ValueError("lora_target_modules cannot be empty when use_lora=True")

    @property
    def effective_batch_size(self) -> int:
        """Samples consumed per optimizer step."""
        return self.batch_size * self.gradient_accumulation_steps

    def save(self, path: str):
        """Serialize the config as pretty-printed JSON at *path*."""
        with open(path, 'w') as f:
            json.dump(asdict(self), f, indent=2)

    @classmethod
    def load(cls, path: str) -> 'TrainingConfig':
        """Load a config previously written by :meth:`save`."""
        with open(path) as f:
            return cls(**json.load(f))
'TrainingConfig': + with open(path) as f: + return cls(**json.load(f)) + + +# ============================================================================= +# Dataset +# ============================================================================= + +class ExtractorDataset(Dataset): + """ + Dataset for GLiNER2 training with multi-format support. + + Supports all formats through DataLoader_Factory: + - JSONL file path(s) + - List of InputExample objects + - TrainingDataset object + - List of raw dict records + + Examples + -------- + >>> # From JSONL + >>> dataset = ExtractorDataset("train.jsonl") + + >>> # From multiple JSONL files + >>> dataset = ExtractorDataset(["train1.jsonl", "train2.jsonl"]) + + >>> # From InputExample list + >>> dataset = ExtractorDataset(examples) + """ + + def __init__( + self, + data: TrainDataInput, + max_samples: int = -1, + shuffle: bool = True, + seed: int = 42, + validate: bool = False, + ): + """ + Initialize dataset from various input formats. + + Parameters + ---------- + data : TrainDataInput + Training data in any supported format. + max_samples : int, default=-1 + Maximum samples to use (-1 = all). + shuffle : bool, default=True + Whether to shuffle the data. + seed : int, default=42 + Random seed for shuffling. + validate : bool, default=False + Whether to validate the data. Validation is always strict: + checks that entity spans, relation values, and structure + field values exist in the text. 
+ """ + self.data = DataLoader_Factory.load( + data=data, + max_samples=max_samples, + shuffle=shuffle, + seed=seed, + validate=validate, + ) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx: int) -> Tuple[str, Dict]: + record = self.data[idx] + # Handle both formats + if "input" in record: + return record["input"], record["output"] + else: + return record["text"], record["schema"] + + # Factory methods for explicit creation + @classmethod + def from_jsonl(cls, paths: Union[str, Path, List], **kwargs) -> 'ExtractorDataset': + """Create from JSONL file(s).""" + return cls(paths, **kwargs) + + @classmethod + def from_examples(cls, examples: List[InputExample], **kwargs) -> 'ExtractorDataset': + """Create from list of InputExample.""" + return cls(examples, **kwargs) + + @classmethod + def from_training_dataset(cls, dataset: TrainingDataset, **kwargs) -> 'ExtractorDataset': + """Create from TrainingDataset.""" + return cls(dataset, **kwargs) + + @classmethod + def from_dicts(cls, dicts: List[Dict], **kwargs) -> 'ExtractorDataset': + """Create from list of dicts.""" + return cls(dicts, **kwargs) + + +# ============================================================================= +# Collator +# ============================================================================= + +class ExtractorCollator: + """Data collator that converts raw records to model inputs.""" + + def __init__(self, processor: SchemaTransformer, is_training: bool = True): + self.processor = processor + self.is_training = is_training + + def __call__(self, batch: List[Tuple[str, Dict]]): + """ + Convert batch of (text, schema) tuples to PreprocessedBatch. 
+ + Args: + batch: List of (text, schema) tuples from dataset + + Returns: + PreprocessedBatch ready for model.forward() + """ + if self.is_training: + return self.processor.collate_fn_train(batch) + else: + return self.processor.collate_fn_inference(batch) + + +# ============================================================================= +# Metrics +# ============================================================================= + +@dataclass +class TrainingMetrics: + """Container for training metrics.""" + loss: float = 0.0 + classification_loss: float = 0.0 + structure_loss: float = 0.0 + count_loss: float = 0.0 + learning_rate: float = 0.0 + epoch: float = 0.0 + step: int = 0 + samples_seen: int = 0 + throughput: float = 0.0 + + def to_dict(self) -> Dict[str, float]: + return asdict(self) + + +# ============================================================================= +# Scheduler Factory +# ============================================================================= + +def get_scheduler(optimizer, scheduler_type, num_training_steps, num_warmup_steps, num_cycles=0.5): + """Create learning rate scheduler.""" + def lr_lambda_linear(step): + if step < num_warmup_steps: + return float(step) / float(max(1, num_warmup_steps)) + return max(0.0, float(num_training_steps - step) / float(max(1, num_training_steps - num_warmup_steps))) + + def lr_lambda_cosine(step): + if step < num_warmup_steps: + return float(step) / float(max(1, num_warmup_steps)) + progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + + def lr_lambda_cosine_restarts(step): + if step < num_warmup_steps: + return float(step) / float(max(1, num_warmup_steps)) + progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0)))) + + def lr_lambda_constant(step): + if step < 
num_warmup_steps: + return float(step) / float(max(1, num_warmup_steps)) + return 1.0 + + schedulers = { + "linear": lr_lambda_linear, + "cosine": lr_lambda_cosine, + "cosine_restarts": lr_lambda_cosine_restarts, + "constant": lr_lambda_constant, + } + + if scheduler_type not in schedulers: + raise ValueError(f"Unknown scheduler: {scheduler_type}") + + return LambdaLR(optimizer, schedulers[scheduler_type]) + + +# ============================================================================= +# Main Trainer +# ============================================================================= + +class GLiNER2Trainer: + """ + World-class trainer for GLiNER2 with flexible multi-format data input. + + Parameters + ---------- + model : nn.Module + The GLiNER2 model to train. + config : TrainingConfig + Training configuration. + processor : SchemaTransformer, optional + Schema processor. If None, uses model.processor. + train_data : TrainDataInput, optional + Training data (can be provided here or in train()). + eval_data : TrainDataInput, optional + Evaluation data. + compute_metrics : Callable, optional + Custom metrics function. 
+ + Supported Data Formats + ---------------------- + - Single JSONL file path (str or Path) + - List of JSONL file paths + - List of InputExample objects + - TrainingDataset object + - List of raw dict records + + Examples + -------- + >>> # With InputExample list + >>> examples = [InputExample(...), InputExample(...)] + >>> trainer = GLiNER2Trainer(model, config) + >>> trainer.train(train_data=examples) + + >>> # With JSONL file + >>> trainer.train(train_data="train.jsonl") + + >>> # With multiple JSONL files + >>> trainer.train(train_data=["train1.jsonl", "train2.jsonl"]) + + >>> # With TrainingDataset + >>> dataset = TrainingDataset.load("train.jsonl") + >>> trainer.train(train_data=dataset) + """ + + def __init__( + self, + model: nn.Module, + config: TrainingConfig, + processor: SchemaTransformer = None, + train_data: TrainDataInput = None, + eval_data: TrainDataInput = None, + compute_metrics: Optional[Callable] = None, + ): + self.model = model + self.config = config + self.processor = processor or getattr(model, 'processor', None) + if self.processor is None: + raise ValueError("Processor must be provided or model must have .processor attribute") + + self.train_data = train_data + self.eval_data = eval_data + self.compute_metrics = compute_metrics + + self._setup_seed() + self._setup_device() + self._setup_output_dir() + self._setup_logging() + + self.global_step = 0 + self.epoch = 0 + self.best_metric = float('inf') if not config.greater_is_better else float('-inf') + self.patience_counter = 0 + self.train_metrics_history = [] + self.eval_metrics_history = [] + + self.optimizer = None + self.scheduler = None + self.scaler = None + self.wandb_run = None + self.progress_bar = None + + # LoRA state + self.lora_layers = {} + self._setup_lora() + + def _setup_seed(self): + seed = self.config.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if self.config.deterministic: + 
torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.benchmark = True + + def _setup_device(self): + if self.config.local_rank >= 0: + torch.cuda.set_device(self.config.local_rank) + self.device = torch.device("cuda", self.config.local_rank) + self.is_distributed = True + elif torch.cuda.is_available(): + self.device = torch.device("cuda") + self.is_distributed = False + else: + self.device = torch.device("cpu") + self.is_distributed = False + if self.config.fp16 or self.config.bf16: + logger.warning("Mixed precision disabled on CPU") + self.config.fp16 = False + self.config.bf16 = False + self.model.to(self.device) + logger.info(f"Using device: {self.device}") + + def _setup_output_dir(self): + self.output_dir = Path(self.config.output_dir) + self.logs_dir = self.output_dir / "logs" + if self.is_main_process: + self.output_dir.mkdir(parents=True, exist_ok=True) + self.logs_dir.mkdir(exist_ok=True) + self.config.save(str(self.output_dir / "training_config.json")) + + def _setup_logging(self): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO if self.is_main_process else logging.WARNING, + ) + + # W&B setup (HuggingFace style) + self.wandb_run = None + if self.config.report_to_wandb and self.is_main_process: + try: + import wandb + self.wandb_run = wandb.init( + project=self.config.wandb_project or self.config.experiment_name, + entity=self.config.wandb_entity, + name=self.config.wandb_run_name or f"{self.config.experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", + config=asdict(self.config), + tags=self.config.wandb_tags, + notes=self.config.wandb_notes, + dir=str(self.output_dir), + ) + logger.info(f"W&B run: {self.wandb_run.url}") + except ImportError: + logger.warning("wandb not installed. 
Run: pip install wandb") + self.config.report_to_wandb = False + + def _setup_lora(self): + """Setup LoRA for parameter-efficient fine-tuning if enabled.""" + if not self.config.use_lora: + logger.info("LoRA is disabled") + return + + logger.info("Setting up LoRA for parameter-efficient fine-tuning...") + + # Freeze ALL model parameters BEFORE applying LoRA + for param in self.model.parameters(): + param.requires_grad = False + logger.info("Froze all model parameters for LoRA training") + + # Create LoRA config + lora_config = LoRAConfig( + enabled=True, + r=self.config.lora_r, + alpha=self.config.lora_alpha, + dropout=self.config.lora_dropout, + target_modules=self.config.lora_target_modules, + ) + + # Apply LoRA (encoder: targeted modules, non-encoder: all linear layers) + # LoRA layers' lora_A and lora_B are nn.Parameter created after freezing, + # so they have requires_grad=True by default - only these get trained + self.model, self.lora_layers = apply_lora_to_model( + model=self.model, + config=lora_config, + ) + + # Sync model's _lora_layers attribute + self.model._lora_layers = self.lora_layers + + # Print LoRA information + if self.is_main_process: + print_lora_info(self.model, lora_config) + + # Log parameter counts + lora_params, total_params, percentage = count_lora_parameters(self.model) + logger.info( + f"LoRA setup complete: {lora_params:,} trainable params " + f"out of {total_params:,} total ({percentage:.2f}%)" + ) + + @property + def is_main_process(self) -> bool: + return self.config.local_rank <= 0 + + @staticmethod + def _safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float: + """Safely divide two numbers, returning default if denominator is zero.""" + if denominator == 0: + return default + return numerator / denominator + + def _validate_training_setup(self, train_dataset: ExtractorDataset, eval_dataset: Optional[ExtractorDataset]): + """Validate training setup and raise informative errors for edge cases.""" + # 
Check if dataset is empty + if len(train_dataset) == 0: + raise ValueError("Training dataset is empty. Please provide at least one training example.") + + # Check if dataset is smaller than batch size + if len(train_dataset) < self.config.batch_size: + logger.warning( + f"Training dataset size ({len(train_dataset)}) is smaller than batch_size " + f"({self.config.batch_size}). Adjusting batch_size to {len(train_dataset)}." + ) + # We'll handle this in _create_dataloader by adjusting drop_last + + # Check early stopping configuration + if self.config.early_stopping: + if eval_dataset is None: + raise ValueError( + "early_stopping is enabled but no eval_data provided. " + "Please provide eval_data or disable early_stopping." + ) + if len(eval_dataset) == 0: + raise ValueError("Evaluation dataset is empty but early_stopping is enabled.") + + # Check eval strategy configuration + if self.config.eval_strategy == "steps" and eval_dataset is None: + logger.warning( + "eval_strategy='steps' but no eval_data provided. " + "Evaluation will be skipped." + ) + + # Warn about very small datasets + if len(train_dataset) < self.config.gradient_accumulation_steps: + logger.warning( + f"Training dataset size ({len(train_dataset)}) is smaller than " + f"gradient_accumulation_steps ({self.config.gradient_accumulation_steps}). " + f"Training may not work as expected." 
+ ) + + def _flush_gradients(self) -> Optional[float]: + """Flush accumulated gradients at the end of epoch if incomplete cycle exists.""" + # Check if there are accumulated gradients + has_gradients = False + for param in self.model.parameters(): + if param.grad is not None and param.grad.abs().sum() > 0: + has_gradients = True + break + + if not has_gradients: + return None + + # Apply the accumulated gradients + if self.config.fp16: + self.scaler.unscale_(self.optimizer) + + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) + + if self.config.fp16: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + + self.scheduler.step() + self.optimizer.zero_grad() + self.global_step += 1 + + logger.info(f"Flushed incomplete gradient accumulation cycle at end of epoch (grad_norm: {grad_norm:.2f})") + return grad_norm + + def _prepare_data(self, data: TrainDataInput, is_train: bool = True) -> ExtractorDataset: + """Convert any supported data format to ExtractorDataset.""" + if data is None: + return None + + if isinstance(data, ExtractorDataset): + return data + + max_samples = self.config.max_train_samples if is_train else self.config.max_eval_samples + + return ExtractorDataset( + data=data, + max_samples=max_samples, + shuffle=is_train, + seed=self.config.seed, + validate=self.config.validate_data if is_train else False + ) + + def _create_optimizer(self) -> AdamW: + """Create optimizer with appropriate parameters based on LoRA configuration.""" + + if self.config.use_lora: + # When using LoRA: ONLY train LoRA parameters (everything else is frozen) + lora_params = get_lora_parameters(self.model) + + if not lora_params: + raise ValueError("No LoRA parameters found. 
Check LoRA configuration.") + + logger.info(f"Optimizer: LoRA params only = {len(lora_params)}, LR={self.config.task_lr}") + + return AdamW( + [{"params": lora_params, "lr": self.config.task_lr, "weight_decay": self.config.weight_decay}], + betas=(self.config.adam_beta1, self.config.adam_beta2), + eps=self.config.adam_epsilon, + ) + else: + # Normal training: separate LRs for encoder and task-specific layers + encoder_params = [] + task_params = [] + for name, param in self.model.named_parameters(): + if not param.requires_grad: + continue + if "encoder" in name: + encoder_params.append(param) + else: + task_params.append(param) + + return AdamW( + [ + {"params": encoder_params, "lr": self.config.encoder_lr, "weight_decay": self.config.weight_decay}, + {"params": task_params, "lr": self.config.task_lr, "weight_decay": self.config.weight_decay}, + ], + betas=(self.config.adam_beta1, self.config.adam_beta2), + eps=self.config.adam_epsilon, + ) + + def _create_dataloader(self, dataset: ExtractorDataset, batch_size: int, shuffle: bool = True, is_training: bool = True) -> DataLoader: + sampler = None + if self.is_distributed: + sampler = DistributedSampler(dataset, shuffle=shuffle) + shuffle = False + + collator = ExtractorCollator(self.processor, is_training=is_training) + + # Fix Bug #1 & #9: Handle small datasets + # If dataset is smaller than batch_size, adjust to prevent empty dataloader + effective_batch_size = min(batch_size, len(dataset)) + drop_last = is_training and len(dataset) > batch_size + + # Adjust num_workers for small datasets + effective_num_workers = self.config.num_workers if len(dataset) > self.config.num_workers else 0 + + return DataLoader( + dataset, + batch_size=effective_batch_size, + shuffle=shuffle, + sampler=sampler, + num_workers=effective_num_workers, + pin_memory=self.config.pin_memory, + prefetch_factor=self.config.prefetch_factor if effective_num_workers > 0 else None, + collate_fn=collator, + drop_last=drop_last, + 
persistent_workers=effective_num_workers > 0, + ) + + def train( + self, + train_data: TrainDataInput = None, + eval_data: TrainDataInput = None, + ) -> Dict[str, Any]: + """ + Main training loop. + + Parameters + ---------- + train_data : TrainDataInput, optional + Training data. Supports all formats: + - str/Path: JSONL file path + - List[str/Path]: Multiple JSONL files + - List[InputExample]: List of examples + - TrainingDataset: Dataset object + - List[Dict]: Raw records + + eval_data : TrainDataInput, optional + Evaluation data (same formats supported). + + Returns + ------- + Dict[str, Any] + Training summary with metrics history. + """ + # Prepare datasets + train_data = train_data or self.train_data + eval_data = eval_data or self.eval_data + + if train_data is None: + raise ValueError("No training data provided") + + train_dataset = self._prepare_data(train_data, is_train=True) + eval_dataset = self._prepare_data(eval_data, is_train=False) if eval_data else None + + # Fix Bug #7: Validate training setup + self._validate_training_setup(train_dataset, eval_dataset) + + train_loader = self._create_dataloader(train_dataset, self.config.batch_size, shuffle=True, is_training=True) + + # Fix Bug #1: Check if dataloader is empty + if len(train_loader) == 0: + raise ValueError( + f"Training dataloader is empty. Dataset size: {len(train_dataset)}, " + f"Batch size: {self.config.batch_size}. Please reduce batch_size or add more data." + ) + + # Calculate steps + num_update_steps_per_epoch = len(train_loader) // self.config.gradient_accumulation_steps + + # Fix Bug #1: Handle case where num_update_steps_per_epoch is 0 + if num_update_steps_per_epoch == 0: + # If gradient accumulation is larger than dataloader, we have at least the batches we can process + num_update_steps_per_epoch = 1 + logger.warning( + f"gradient_accumulation_steps ({self.config.gradient_accumulation_steps}) is larger than " + f"batches per epoch ({len(train_loader)}). 
Setting to 1 update step per epoch." + ) + + if self.config.max_steps > 0: + max_steps = self.config.max_steps + num_epochs = math.ceil(max_steps / num_update_steps_per_epoch) + else: + max_steps = num_update_steps_per_epoch * self.config.num_epochs + num_epochs = self.config.num_epochs + + warmup_steps = self.config.warmup_steps or int(max_steps * self.config.warmup_ratio) + + # Create optimizer and scheduler + self.optimizer = self._create_optimizer() + self.scheduler = get_scheduler(self.optimizer, self.config.scheduler_type, max_steps, warmup_steps, self.config.num_cycles) + + # Mixed precision + use_amp = self.config.fp16 or self.config.bf16 + amp_dtype = torch.bfloat16 if self.config.bf16 else torch.float16 + self.scaler = GradScaler(enabled=self.config.fp16) + + # Logging + logger.info("***** Running Training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num epochs = {num_epochs}") + logger.info(f" Batch size = {self.config.batch_size}") + logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}") + logger.info(f" Effective batch size = {self.config.effective_batch_size}") + logger.info(f" Total optimization steps = {max_steps}") + logger.info(f" Warmup steps = {warmup_steps}") + + # Log trainable parameters + if self.config.use_lora: + lora_params, total_params, percentage = count_lora_parameters(self.model) + logger.info(f" LoRA enabled: {lora_params:,} trainable / {total_params:,} total ({percentage:.2f}%)") + else: + trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in self.model.parameters()) + percentage = (trainable_params / total_params * 100) if total_params > 0 else 0.0 + logger.info(f" Trainable parameters: {trainable_params:,} / {total_params:,} ({percentage:.2f}%)") + + # Training state + self.model.train() + self.processor.change_mode(is_training=True) + self.global_step = 0 + self.epoch = 0 + tr_loss = 0.0 + 
+ start_time = time.time() + samples_seen = 0 + + self.progress_bar = tqdm(total=max_steps, desc="Training", disable=not self.is_main_process) + + for epoch in range(num_epochs): + self.epoch = epoch + + if self.is_distributed: + train_loader.sampler.set_epoch(epoch) + + epoch_loss = 0.0 + epoch_steps = 0 + + for step, batch in enumerate(train_loader): + samples_seen += len(batch) + + try: + with autocast(enabled=use_amp, dtype=amp_dtype): + outputs = self.model(batch) + loss = outputs["total_loss"] + + if self.config.gradient_accumulation_steps > 1: + loss = loss / self.config.gradient_accumulation_steps + + # Skip batches where loss doesn't require grad (edge cases in data) + if not loss.requires_grad: + logger.warning( + f"Skipping batch {step}: loss doesn't require grad " + f"(loss={loss.item():.4f}). This may indicate edge cases in your data." + ) + continue + + if self.config.fp16: + self.scaler.scale(loss).backward() + else: + loss.backward() + + tr_loss += loss.item() + epoch_loss += loss.item() + epoch_steps += 1 + + except torch.cuda.OutOfMemoryError: + logger.warning( + f"OOM at step {step}, batch skipped. " + f"Consider reducing batch_size or max sequence length." 
+ ) + torch.cuda.empty_cache() + gc.collect() + self.optimizer.zero_grad() + continue + + if (step + 1) % self.config.gradient_accumulation_steps == 0: + if self.config.fp16: + self.scaler.unscale_(self.optimizer) + + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) + + if self.config.fp16: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + + self.scheduler.step() + self.optimizer.zero_grad() + self.global_step += 1 + + if self.global_step % self.config.logging_steps == 0: + elapsed = time.time() - start_time + # Fix Bug #2: Safe division for metrics + avg_loss = self._safe_divide(tr_loss, self.config.logging_steps, default=tr_loss) + # Fix Bug #5: Safe division for epoch progress + epoch_progress = self._safe_divide(step, len(train_loader), default=0.0) + metrics = TrainingMetrics( + loss=avg_loss, + classification_loss=outputs.get("classification_loss", torch.tensor(0)).item(), + structure_loss=outputs.get("structure_loss", torch.tensor(0)).item(), + count_loss=outputs.get("count_loss", torch.tensor(0)).item(), + learning_rate=self.scheduler.get_last_lr()[0], + epoch=epoch + epoch_progress, + step=self.global_step, + samples_seen=samples_seen, + throughput=self._safe_divide(samples_seen, elapsed, default=0.0), + ) + self._log_metrics(metrics, prefix="train") + tr_loss = 0.0 + + if self.config.eval_strategy == "steps" and self.global_step % self.config.eval_steps == 0: + if eval_dataset: + self._evaluate(eval_dataset) + self.model.train() + self.processor.change_mode(is_training=True) + self._save_checkpoint(f"checkpoint-{self.global_step}") + + self.progress_bar.update(1) + + if self.global_step >= max_steps: + break + + # Fix Bug #6: Flush incomplete gradient accumulation at end of epoch + if epoch_steps % self.config.gradient_accumulation_steps != 0: + grad_norm = self._flush_gradients() + if grad_norm is not None: + logger.info(f"Applied incomplete gradient accumulation at 
end of epoch {epoch + 1}") + + # Fix Bug #3: Safe division for epoch loss + avg_epoch_loss = self._safe_divide(epoch_loss, epoch_steps, default=0.0) + logger.info(f"Epoch {epoch + 1}/{num_epochs} - Loss: {avg_epoch_loss:.4f}") + + if self.config.eval_strategy == "epoch": + if eval_dataset: + eval_metrics = self._evaluate(eval_dataset) + self.model.train() + self.processor.change_mode(is_training=True) + if self.config.early_stopping and self._check_early_stopping(eval_metrics): + logger.info(f"Early stopping triggered at epoch {epoch + 1}") + break + self._save_checkpoint(f"checkpoint-epoch-{epoch + 1}") + + if self.global_step >= max_steps: + break + + self.progress_bar.close() + self.progress_bar = None + + if self.is_main_process: + self._save_checkpoint("final") + if self.config.report_to_wandb: + import wandb + wandb.summary["best_metric"] = self.best_metric + wandb.summary["total_steps"] = self.global_step + wandb.finish() + + total_time = time.time() - start_time + return { + "total_steps": self.global_step, + "total_epochs": self.epoch + 1, + "total_time_seconds": total_time, + "samples_per_second": samples_seen / total_time, + "best_metric": self.best_metric, + "train_metrics_history": self.train_metrics_history, + "eval_metrics_history": self.eval_metrics_history, + } + + def _evaluate(self, eval_dataset: ExtractorDataset) -> Dict[str, float]: + logger.info("Running evaluation...") + self.model.eval() + self.processor.change_mode(is_training=False) + + eval_loader = self._create_dataloader(eval_dataset, self.config.eval_batch_size, shuffle=False, is_training=False) + + # Fix Bug #4: Check if eval dataloader is empty + if len(eval_loader) == 0: + logger.warning( + f"Evaluation dataloader is empty. Dataset size: {len(eval_dataset)}, " + f"Batch size: {self.config.eval_batch_size}. Skipping evaluation." 
+ ) + return { + "eval_loss": 0.0, + "eval_classification_loss": 0.0, + "eval_structure_loss": 0.0, + "eval_count_loss": 0.0, + "step": self.global_step, + "epoch": self.epoch, + } + + total_loss = 0.0 + total_cls_loss = 0.0 + total_struct_loss = 0.0 + total_count_loss = 0.0 + num_batches = 0 + + use_amp = self.config.fp16 or self.config.bf16 + amp_dtype = torch.bfloat16 if self.config.bf16 else torch.float16 + + with torch.no_grad(): + for batch in tqdm(eval_loader, desc="Evaluating", disable=not self.is_main_process): + with autocast(enabled=use_amp, dtype=amp_dtype): + outputs = self.model(batch) + + # Fix Bug #10: Move tensors to CPU to prevent memory leak + total_loss += outputs["total_loss"].detach().cpu().item() + total_cls_loss += outputs.get("classification_loss", torch.tensor(0)).detach().cpu().item() + total_struct_loss += outputs.get("structure_loss", torch.tensor(0)).detach().cpu().item() + total_count_loss += outputs.get("count_loss", torch.tensor(0)).detach().cpu().item() + num_batches += 1 + + # Fix Bug #4: Safe division for evaluation metrics + metrics = { + "eval_loss": self._safe_divide(total_loss, num_batches, default=0.0), + "eval_classification_loss": self._safe_divide(total_cls_loss, num_batches, default=0.0), + "eval_structure_loss": self._safe_divide(total_struct_loss, num_batches, default=0.0), + "eval_count_loss": self._safe_divide(total_count_loss, num_batches, default=0.0), + "step": self.global_step, + "epoch": self.epoch, + } + + if self.compute_metrics: + metrics.update(self.compute_metrics(self.model, eval_dataset)) + + self._log_metrics(metrics, prefix="eval") + self.eval_metrics_history.append(metrics) + + metric_value = metrics.get(self.config.metric_for_best, metrics["eval_loss"]) + is_best = ( + (self.config.greater_is_better and metric_value > self.best_metric) or + (not self.config.greater_is_better and metric_value < self.best_metric) + ) + + if is_best: + self.best_metric = metric_value + if self.config.save_best: + 
self._save_checkpoint("best") + logger.info(f"New best {self.config.metric_for_best}: {self.best_metric:.4f}") + + return metrics + + def _check_early_stopping(self, metrics: Dict[str, float]) -> bool: + metric_value = metrics.get(self.config.metric_for_best, metrics["eval_loss"]) + if self.config.greater_is_better: + improved = metric_value > self.best_metric + self.config.early_stopping_threshold + else: + improved = metric_value < self.best_metric - self.config.early_stopping_threshold + + if improved: + self.patience_counter = 0 + else: + self.patience_counter += 1 + + return self.patience_counter >= self.config.early_stopping_patience + + def _log_metrics(self, metrics: Union[Dict, TrainingMetrics], prefix: str = ""): + """Log metrics with safe handling of edge cases.""" + if isinstance(metrics, TrainingMetrics): + metrics = metrics.to_dict() + + # Handle empty metrics gracefully + if not metrics: + logger.warning("Attempted to log empty metrics") + return + + # Update progress bar with key metrics + if self.is_main_process and self.progress_bar is not None: + postfix = {} + for key, value in metrics.items(): + if key in ["loss", "learning_rate", "throughput"]: + if isinstance(value, float): + if math.isnan(value): + postfix[key] = "NaN" + elif math.isinf(value): + postfix[key] = "Inf" + elif key == "learning_rate": + postfix["lr"] = f"{value:.2e}" + elif key == "throughput": + postfix["samples/s"] = f"{value:.1f}" + else: + postfix[key] = f"{value:.4f}" + + # Add epoch info if available + if "epoch" in metrics: + postfix["epoch"] = f"{metrics['epoch']:.1f}" + + if postfix: + self.progress_bar.set_postfix(postfix) + + # W&B logging + if self.config.report_to_wandb and self.is_main_process: + try: + import wandb + # Filter out NaN and Inf values for wandb + wandb_metrics = { + k: v + for k, v in metrics.items() + if isinstance(v, (int, float)) and not (math.isnan(v) or math.isinf(v)) + } + if wandb_metrics: + wandb.log(wandb_metrics, step=self.global_step) + 
except Exception as e: + logger.warning(f"Failed to log to wandb: {e}") + + if prefix == "train": + self.train_metrics_history.append(metrics) + + def _save_checkpoint(self, name: str): + if not self.is_main_process: + return + + checkpoint_dir = self.output_dir / name + checkpoint_dir.mkdir(exist_ok=True) + + save_start = time.time() + + # Handle adapter-only saves when using LoRA + if self.config.use_lora and self.config.save_adapter_only: + from gliner2.training.lora import save_lora_adapter + save_lora_adapter(self.model, checkpoint_dir) + checkpoint_type = "adapter" + trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + else: + # Full model save: merge LoRA weights if present + lora_was_merged = False + if self.config.use_lora and self.lora_layers: + first_lora_layer = next(iter(self.lora_layers.values())) + if not first_lora_layer.merged: + num_merged = merge_lora_weights(self.model) + lora_was_merged = True + + # Save the model (with merged weights if LoRA was used) + self.model.save_pretrained(str(checkpoint_dir)) + + # Unmerge weights after saving to continue training with LoRA + if lora_was_merged: + from gliner2.training.lora import unmerge_lora_weights + unmerge_lora_weights(self.model) + + # Save LoRA configuration if used + if self.config.use_lora: + lora_config_dict = { + "use_lora": True, + "lora_r": self.config.lora_r, + "lora_alpha": self.config.lora_alpha, + "lora_dropout": self.config.lora_dropout, + "lora_target_modules": self.config.lora_target_modules, + "merged": True, + } + import json + with open(checkpoint_dir / "lora_config.json", "w") as f: + json.dump(lora_config_dict, f, indent=2) + + checkpoint_type = "full" + trainable_params = sum(p.numel() for p in self.model.parameters()) + + save_time = time.time() - save_start + checkpoint_size_mb = sum(f.stat().st_size for f in checkpoint_dir.rglob('*') if f.is_file()) / (1024 * 1024) + + # World-class logging + logger.info( + f"πŸ’Ύ Saved {checkpoint_type} 
checkpoint '{name}' | " + f"step {self.global_step} | epoch {self.epoch + 1:.1f} | " + f"{trainable_params:,} params | {checkpoint_size_mb:.1f}MB | {save_time:.1f}s" + ) + + # Save model artifacts to W&B for best and final checkpoints + if self.config.report_to_wandb and name in ["best", "final"]: + try: + import wandb + artifact = wandb.Artifact( + name=f"model-{self.config.experiment_name}-{name}", + type="model", + metadata={ + "step": self.global_step, + "epoch": self.epoch, + "checkpoint_type": checkpoint_type, + "params": trainable_params, + "size_mb": checkpoint_size_mb, + } + ) + artifact.add_dir(str(checkpoint_dir)) + wandb.log_artifact(artifact) + except Exception as e: + logger.warning(f"W&B artifact upload failed: {e}") + + self._cleanup_checkpoints() + + def _cleanup_checkpoints(self): + if self.config.save_total_limit <= 0: + return + + checkpoints = sorted( + [d for d in self.output_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")], + key=lambda x: x.stat().st_mtime, + ) + protected = {"best", "final"} + checkpoints = [c for c in checkpoints if c.name not in protected] + + while len(checkpoints) > self.config.save_total_limit: + oldest = checkpoints.pop(0) + shutil.rmtree(oldest) + logger.info(f"Removed old checkpoint: {oldest.name}") + + def load_checkpoint(self, checkpoint_path: str): + """ + Load model weights from a checkpoint. + + Handles both adapter-only and full checkpoints. + Note: Training always starts fresh (no optimizer/scheduler state loaded). 
+ """ + from gliner2.training.lora import LoRAAdapterConfig + + checkpoint_dir = Path(checkpoint_path) + + if LoRAAdapterConfig.is_adapter_path(checkpoint_path): + # Adapter checkpoint - load adapter onto existing model + logger.info(f"Loading LoRA adapter from {checkpoint_path}") + self.model.load_adapter(checkpoint_path) + self.lora_layers = self.model._lora_layers + else: + # Full model checkpoint + lora_config_path = checkpoint_dir / "lora_config.json" + if lora_config_path.exists(): + import json + with open(lora_config_path) as f: + lora_config = json.load(f) + logger.info( + f"Checkpoint has LoRA config (r={lora_config.get('lora_r')}, " + f"alpha={lora_config.get('lora_alpha')}, merged weights)" + ) + + # Load model (with merged weights if it was trained with LoRA) + self.model = self.model.__class__.from_pretrained(str(checkpoint_dir)) + self.model.to(self.device) + + # Re-apply LoRA if enabled in current config + if self.config.use_lora: + logger.info("Applying LoRA to loaded model...") + self.lora_layers = {} + self._setup_lora() + + logger.info(f"βœ“ Loaded checkpoint: {checkpoint_path}") + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + +def train_gliner2( + model_path: str, + train_data: TrainDataInput, + output_dir: str = "./output", + eval_data: TrainDataInput = None, + **config_kwargs, +) -> Dict[str, Any]: + """ + Convenience function for training GLiNER2. + + Parameters + ---------- + model_path : str + Path to pretrained model. + train_data : TrainDataInput + Training data in any supported format: + - JSONL path(s) + - List of InputExample + - TrainingDataset + - List of dicts + output_dir : str + Output directory for checkpoints. + eval_data : TrainDataInput, optional + Evaluation data. + **config_kwargs + Additional TrainingConfig parameters. + + Returns + ------- + Dict[str, Any] + Training results. 
+ + Examples + -------- + >>> # Train with JSONL file + >>> results = train_gliner2("model-path", "train.jsonl", num_epochs=10) + + >>> # Train with multiple JSONL files + >>> results = train_gliner2("model-path", ["train1.jsonl", "train2.jsonl"]) + + >>> # Train with InputExample list + >>> examples = [InputExample(...), ...] + >>> results = train_gliner2("model-path", examples) + + >>> # Train with TrainingDataset + >>> dataset = TrainingDataset.load("train.jsonl") + >>> results = train_gliner2("model-path", dataset) + """ + from gliner2 import GLiNER2 + + model = GLiNER2.from_pretrained(model_path) + config = TrainingConfig(output_dir=output_dir, **config_kwargs) + + trainer = GLiNER2Trainer(model=model, config=config) + return trainer.train(train_data=train_data, eval_data=eval_data) \ No newline at end of file diff --git a/packages/GLiNER2/pyproject.toml b/packages/GLiNER2/pyproject.toml new file mode 100644 index 0000000..1e45071 --- /dev/null +++ b/packages/GLiNER2/pyproject.toml @@ -0,0 +1,25 @@ +[build-system] +requires = ["setuptools>=61.0.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +include = ["gliner2", "gliner2.*"] + +[tool.setuptools.dynamic] +version = {attr = "gliner2.__version__"} + +[project] +name = "gliner2" +readme = "README.md" +requires-python = ">=3.8" + +maintainers = [ + {name = "Urchade Zaratiana"}, +] + +dependencies = [ + "gliner", + "pydantic>=2.0.0", +] + +dynamic = ["version"] \ No newline at end of file diff --git a/packages/GLiNER2/tutorial/1-classification.md b/packages/GLiNER2/tutorial/1-classification.md new file mode 100644 index 0000000..e3a2599 --- /dev/null +++ b/packages/GLiNER2/tutorial/1-classification.md @@ -0,0 +1,663 @@ +# GLiNER2 Classification Tutorial + +This tutorial covers all the ways to perform text classification with GLiNER2, from simple single-label classification to complex multi-label tasks with custom configurations. 
+ +## Table of Contents +- [Setup](#setup) +- [Single-Label Classification](#single-label-classification) +- [Multi-Label Classification](#multi-label-classification) +- [Classification with Descriptions](#classification-with-descriptions) +- [Using the Quick API](#using-the-quick-api) +- [Multiple Classification Tasks](#multiple-classification-tasks) +- [Advanced Configurations](#advanced-configurations) +- [Best Practices](#best-practices) + +## Setup + +```python +from gliner2 import GLiNER2 + +# Load the pre-trained model +extractor = GLiNER2.from_pretrained("your-model-name") +``` + +## Single-Label Classification + +The simplest form - classify text into one of several categories. + +### Basic Example + +```python +# Define the schema +schema = extractor.create_schema().classification( + "sentiment", + ["positive", "negative", "neutral"] +) + +# Extract +text = "This product exceeded my expectations! Absolutely love it." +results = extractor.extract(text, schema) +print(results) +# Expected output: {'sentiment': 'positive'} +``` + +### With Confidence Scores + +```python +# Same schema as above +schema = extractor.create_schema().classification( + "sentiment", + ["positive", "negative", "neutral"] +) + +text = "The service was okay, nothing special but not bad either." +results = extractor.extract(text, schema, include_confidence=True) +print(results) +# Expected output: {'sentiment': {'label': 'neutral', 'confidence': 0.82}} +``` + +## Multi-Label Classification + +When text can belong to multiple categories simultaneously. + +```python +# Multi-label classification +schema = extractor.create_schema().classification( + "topics", + ["technology", "business", "health", "politics", "sports"], + multi_label=True, + cls_threshold=0.3 # Lower threshold for multi-label +) + +text = "Apple announced new health monitoring features in their latest smartwatch, boosting their stock price." 
+results = extractor.extract(text, schema) +print(results) +# Expected output: {'topics': ['technology', 'business', 'health']} + +# With confidence scores +results = extractor.extract(text, schema, include_confidence=True) +print(results) +# Expected output: {'topics': [ +# {'label': 'technology', 'confidence': 0.92}, +# {'label': 'business', 'confidence': 0.78}, +# {'label': 'health', 'confidence': 0.65} +# ]} +``` + +## Classification with Descriptions + +Adding descriptions significantly improves accuracy by providing context. + +```python +# With label descriptions +schema = extractor.create_schema().classification( + "document_type", + { + "invoice": "A bill for goods or services with payment details", + "receipt": "Proof of payment for a completed transaction", + "contract": "Legal agreement between parties with terms and conditions", + "proposal": "Document outlining suggested plans or services with pricing" + } +) + +text = "Please find attached the itemized bill for consulting services rendered in Q3 2024. Payment is due within 30 days." +results = extractor.extract(text, schema) +print(results) +# Expected output: {'document_type': 'invoice'} + +# Another example +text2 = "Thank you for your payment of $500. This confirms your transaction was completed on March 1st, 2024." +results2 = extractor.extract(text2, schema) +print(results2) +# Expected output: {'document_type': 'receipt'} +``` + +## Using the Quick API + +For simple classification tasks without building a schema. + +### Single Task + +```python +text = "The new AI model shows remarkable performance improvements." +results = extractor.classify_text( + text, + {"sentiment": ["positive", "negative", "neutral"]} +) +print(results) +# Expected output: {'sentiment': 'positive'} + +# Another example +text2 = "The software keeps crashing and customer support is unresponsive." 
+results2 = extractor.classify_text( + text2, + {"sentiment": ["positive", "negative", "neutral"]} +) +print(results2) +# Expected output: {'sentiment': 'negative'} +``` + +### Multiple Tasks + +```python +text = "Breaking: Tech giant announces major layoffs amid market downturn" +results = extractor.classify_text( + text, + { + "sentiment": ["positive", "negative", "neutral"], + "urgency": ["high", "medium", "low"], + "category": { + "labels": ["tech", "finance", "politics", "sports"], + "multi_label": False + } + } +) +print(results) +# Expected output: { +# 'sentiment': 'negative', +# 'urgency': 'high', +# 'category': 'tech' +# } +``` + +### Multi-Label with Config + +```python +text = "The smartphone features an amazing camera but disappointing battery life and overheats frequently." +results = extractor.classify_text( + text, + { + "product_aspects": { + "labels": ["camera", "battery", "display", "performance", "design", "heating"], + "multi_label": True, + "cls_threshold": 0.4 + } + } +) +print(results) +# Expected output: {'product_aspects': ['camera', 'battery', 'heating']} + +# Another example +text2 = "Beautiful design with vibrant display, though the camera could be better." +results2 = extractor.classify_text( + text2, + { + "product_aspects": { + "labels": ["camera", "battery", "display", "performance", "design", "heating"], + "multi_label": True, + "cls_threshold": 0.4 + } + } +) +print(results2) +# Expected output: {'product_aspects': ['design', 'display', 'camera']} +``` + +## Multiple Classification Tasks + +You can include multiple classification tasks in a single schema for comprehensive text analysis. 
+
+### Basic Multiple Classifications
+
+```python
+# Multiple independent classifications
+schema = (extractor.create_schema()
+    .classification("sentiment", ["positive", "negative", "neutral"])
+    .classification("language", ["english", "spanish", "french", "german", "other"])
+    .classification("formality", ["formal", "informal", "semi-formal"])
+    .classification("intent", ["question", "statement", "request", "complaint"])
+)
+
+text = "Could you please help me with my order? The service has been disappointing."
+results = extractor.extract(text, schema)
+print(results)
+# Expected output: {
+# 'sentiment': 'negative',
+# 'language': 'english',
+# 'formality': 'formal',
+# 'intent': 'question'
+# }
+
+# Another example
+text2 = "Hey! Just wanted to say your product rocks! 🎉"
+results2 = extractor.extract(text2, schema)
+print(results2)
+# Expected output: {
+# 'sentiment': 'positive',
+# 'language': 'english',
+# 'formality': 'informal',
+# 'intent': 'statement'
+# }
+```
+
+### Mixed Single and Multi-Label Classifications
+
+```python
+# Combine different classification types
+schema = (extractor.create_schema()
+    # Single-label classifications
+    .classification("primary_topic", ["tech", "business", "health", "sports", "politics"])
+    .classification("urgency", ["immediate", "soon", "later", "not_urgent"])
+    
+    # Multi-label classifications
+    .classification("emotions", 
+        ["happy", "sad", "angry", "surprised", "fearful", "disgusted"],
+        multi_label=True,
+        cls_threshold=0.4
+    )
+    .classification("content_flags",
+        ["inappropriate", "spam", "promotional", "personal_info", "financial_info"],
+        multi_label=True,
+        cls_threshold=0.3
+    )
+)
+
+text = "URGENT: I'm thrilled to announce our new product! But concerned about competitor reactions. Please keep confidential."
+results = extractor.extract(text, schema) +print(results) +# Expected output: { +# 'primary_topic': 'business', +# 'urgency': 'immediate', +# 'emotions': ['happy', 'fearful'], +# 'content_flags': ['promotional', 'personal_info'] +# } + +# Another example +text2 = "Just saw the game - absolutely devastated by the loss. Can't believe the referee's terrible decision!" +results2 = extractor.extract(text2, schema) +print(results2) +# Expected output: { +# 'primary_topic': 'sports', +# 'urgency': 'not_urgent', +# 'emotions': ['sad', 'angry'], +# 'content_flags': [] +# } +``` + +### Domain-Specific Multiple Classifications + +```python +# Customer support ticket classification +support_schema = (extractor.create_schema() + .classification("ticket_type", + ["technical_issue", "billing", "feature_request", "bug_report", "other"]) + .classification("priority", + ["critical", "high", "medium", "low"], + cls_threshold=0.7 + ) + .classification("product_area", + { + "authentication": "Login, passwords, security", + "payment": "Payment processing, subscriptions", + "ui": "User interface, design issues", + "performance": "Speed, loading, responsiveness", + "data": "Data loss, corruption, sync issues" + }, + multi_label=True, + cls_threshold=0.5 + ) + .classification("customer_sentiment", + ["very_satisfied", "satisfied", "neutral", "frustrated", "very_frustrated"], + cls_threshold=0.6 + ) + .classification("requires_action", + ["immediate_response", "investigation_needed", "waiting_customer", "resolved"], + multi_label=True + ) +) + +ticket_text = """ +Subject: Cannot login - Urgent! + +I've been trying to login for the past hour but keep getting error messages. +This is critical as I need to process payments for my customers today. +The page just keeps spinning and then times out. I'm extremely frustrated +as this is costing me business. Please fix this immediately! 
+""" + +results = extractor.extract(ticket_text, support_schema) +print(results) +# Expected output: { +# 'ticket_type': 'technical_issue', +# 'priority': 'critical', +# 'product_area': ['authentication', 'payment', 'performance'], +# 'customer_sentiment': 'very_frustrated', +# 'requires_action': ['immediate_response', 'investigation_needed'] +# } + +# Another support ticket example +ticket_text2 = """ +Hi team, + +Thanks for the great product! I was wondering if you could add a dark mode feature? +It would really help with eye strain during late night work sessions. + +Best regards, +Happy Customer +""" + +results2 = extractor.extract(ticket_text2, support_schema) +print(results2) +# Expected output: { +# 'ticket_type': 'feature_request', +# 'priority': 'low', +# 'product_area': ['ui'], +# 'customer_sentiment': 'satisfied', +# 'requires_action': ['waiting_customer'] +# } +``` + +### Sequential Classification with Dependencies + +```python +# Email routing and handling classification +email_schema = (extractor.create_schema() + # Primary classification + .classification("email_category", + ["sales", "support", "hr", "legal", "general"], + cls_threshold=0.6 + ) + + # Secondary classifications based on context + .classification("sales_stage", + ["lead", "qualified", "proposal", "negotiation", "closed"], + cls_threshold=0.5 + ) + .classification("support_type", + ["pre_sales", "technical", "account", "billing"], + cls_threshold=0.5 + ) + + # Action classifications + .classification("required_action", + ["reply_needed", "forward_to_team", "schedule_meeting", "no_action"], + multi_label=True, + cls_threshold=0.4 + ) + .classification("response_timeframe", + ["within_1_hour", "within_24_hours", "within_week", "non_urgent"], + cls_threshold=0.6 + ) +) + +email = """ +Hi Sales Team, + +I'm interested in your enterprise solution. We're currently evaluating vendors +for our upcoming project. Could we schedule a demo next week? We need to make +a decision by month end. 
+ +Best regards, +John from TechCorp +""" + +results = extractor.extract(email, email_schema) +print(results) +# Expected output: { +# 'email_category': 'sales', +# 'sales_stage': 'qualified', +# 'support_type': 'pre_sales', +# 'required_action': ['reply_needed', 'schedule_meeting'], +# 'response_timeframe': 'within_24_hours' +# } + +# HR email example +email2 = """ +Dear HR Department, + +I need to update my tax withholding information. Could someone please send me +the necessary forms? This is somewhat urgent as I need this changed before the +next payroll cycle. + +Thank you, +Sarah +""" + +results2 = extractor.extract(email2, email_schema) +print(results2) +# Expected output: { +# 'email_category': 'hr', +# 'sales_stage': 'lead', # May have noise in non-sales emails +# 'support_type': 'account', +# 'required_action': ['reply_needed'], +# 'response_timeframe': 'within_24_hours' +# } +``` + +### Complex Analysis with Multiple Classifications + +```python +# Content moderation and analysis +content_schema = (extractor.create_schema() + # Content classifications + .classification("content_type", + ["article", "comment", "review", "social_post", "message"]) + .classification("primary_language", + ["english", "spanish", "french", "other"]) + + # Quality assessments + .classification("quality_score", + ["excellent", "good", "average", "poor", "spam"], + cls_threshold=0.7 + ) + .classification("originality", + ["original", "derivative", "duplicate", "plagiarized"], + cls_threshold=0.8 + ) + + # Safety and compliance + .classification("safety_flags", + { + "hate_speech": "Contains discriminatory or hateful content", + "violence": "Contains violent or threatening content", + "adult": "Contains adult or explicit content", + "misinformation": "Contains potentially false information", + "personal_info": "Contains personal identifying information" + }, + multi_label=True, + cls_threshold=0.3 + ) + + # Engagement predictions + .classification("engagement_potential", + 
["viral", "high", "medium", "low"], + cls_threshold=0.6 + ) + .classification("audience_fit", + ["general", "professional", "academic", "youth", "senior"], + multi_label=True, + cls_threshold=0.5 + ) +) + +content_text = """ +Just discovered this amazing productivity hack that doubled my output! +Here's what I do: I wake up at 5 AM, meditate for 20 minutes, then work +in 90-minute focused blocks. The results have been incredible. My email +is john.doe@example.com if you want more tips! +""" + +results = extractor.extract(content_text, content_schema) +print(results) +# Expected output: { +# 'content_type': 'social_post', +# 'primary_language': 'english', +# 'quality_score': 'good', +# 'originality': 'original', +# 'safety_flags': ['personal_info'], +# 'engagement_potential': 'high', +# 'audience_fit': ['general', 'professional'] +# } + +# Review example +review_text = """ +Worst product ever!!! Total scam! Don't buy this garbage. The company should +be shut down for selling this junk. I'm going to report them to authorities. +""" + +results2 = extractor.extract(review_text, content_schema) +print(results2) +# Expected output: { +# 'content_type': 'review', +# 'primary_language': 'english', +# 'quality_score': 'poor', +# 'originality': 'original', +# 'safety_flags': ['violence'], # Due to aggressive language +# 'engagement_potential': 'low', +# 'audience_fit': ['general'] +# } +``` + +## Advanced Configurations + +### Custom Thresholds + +```python +# High-precision classification +schema = extractor.create_schema().classification( + "is_spam", + ["spam", "not_spam"], + cls_threshold=0.9 # Very high confidence required +) + +text = "Congratulations! You've won $1,000,000! Click here to claim your prize now!" 
+results = extractor.extract(text, schema) +print(results) +# Expected output: {'is_spam': 'spam'} + +# Different thresholds for different tasks +schema = (extractor.create_schema() + .classification("priority", ["urgent", "high", "normal", "low"], cls_threshold=0.8) + .classification("department", ["sales", "support", "billing", "other"], cls_threshold=0.5) +) + +text = "URGENT: Customer threatening to cancel $50k contract due to billing error" +results = extractor.extract(text, schema) +print(results) +# Expected output: { +# 'priority': 'urgent', +# 'department': 'billing' +# } +``` + +### Custom Activation Functions + +```python +# Force specific activation +schema = extractor.create_schema().classification( + "category", + ["A", "B", "C", "D"], + class_act="softmax" # Options: "sigmoid", "softmax", "auto" +) + +text = "This clearly belongs to category B based on the criteria." +results = extractor.extract(text, schema) +print(results) +# Expected output: {'category': 'B'} +``` + +### Complex Multi-Label Example + +```python +# Email classification system +schema = extractor.create_schema().classification( + "email_tags", + { + "action_required": "Email requires recipient to take action", + "meeting_request": "Email contains meeting invitation or scheduling", + "project_update": "Email contains project status or updates", + "urgent": "Email marked as urgent or time-sensitive", + "question": "Email contains questions requiring answers", + "fyi": "Informational email requiring no action" + }, + multi_label=True, + cls_threshold=0.35 +) + +email_text = """ +Hi team, + +Quick update on Project Alpha: We're ahead of schedule! + +However, I need your input on the design mockups by EOD tomorrow. +Can we schedule a 30-min call this week to discuss? + +This is quite urgent as the client is waiting. 
+ +Best, +Sarah +""" + +results = extractor.extract(email_text, schema) +print(results) +# Expected output: { +# 'email_tags': ['action_required', 'meeting_request', 'project_update', 'urgent', 'question'] +# } + +# FYI email example +email_text2 = """ +Team, + +Just wanted to let everyone know that I'll be out of office next Monday for a +doctor's appointment. I'll be back Tuesday morning. + +Thanks, +Mark +""" + +results2 = extractor.extract(email_text2, schema) +print(results2) +# Expected output: { +# 'email_tags': ['fyi'] +# } +``` + +## Best Practices + +1. **Use Descriptions**: Always provide label descriptions when possible + ```python + # Good - with descriptions + schema = extractor.create_schema().classification( + "intent", + { + "purchase": "User wants to buy a product", + "return": "User wants to return a product", + "inquiry": "User asking for information" + } + ) + + # Less effective - no context + schema = extractor.create_schema().classification( + "intent", + ["purchase", "return", "inquiry"] + ) + ``` + +2. **Adjust Thresholds**: Lower thresholds for multi-label (0.3-0.5), higher for single-label (0.5-0.7) + +3. **Multi-Label Strategy**: Use multi-label when categories aren't mutually exclusive + ```python + # Good use of multi-label + schema = extractor.create_schema().classification( + "product_features", + ["waterproof", "wireless", "rechargeable", "portable"], + multi_label=True + ) + + # Should be single-label + schema = extractor.create_schema().classification( + "size", + ["small", "medium", "large"], + multi_label=False # Sizes are mutually exclusive + ) + ``` + +4. 
**Test with Real Examples**: Always test with actual text samples from your domain + +## Common Use Cases + +- **Sentiment Analysis**: Customer feedback, reviews, social media +- **Intent Classification**: Chatbots, customer service routing +- **Document Classification**: Email filtering, document management +- **Content Moderation**: Toxic content, spam detection +- **Topic Classification**: News categorization, content tagging \ No newline at end of file diff --git a/packages/GLiNER2/tutorial/10-lora_adapters.md b/packages/GLiNER2/tutorial/10-lora_adapters.md new file mode 100644 index 0000000..fa18a9f --- /dev/null +++ b/packages/GLiNER2/tutorial/10-lora_adapters.md @@ -0,0 +1,973 @@ +# Tutorial 10: LoRA Adapters - Multi-Domain Inference + +## Table of Contents +1. [Introduction](#introduction) +2. [Why Use LoRA Adapters?](#why-use-lora-adapters) +3. [Training Your First Adapter](#training-your-first-adapter) +4. [Training Multiple Domain Adapters](#training-multiple-domain-adapters) +5. [Loading and Swapping Adapters](#loading-and-swapping-adapters) +6. [Real-World Use Cases](#real-world-use-cases) +7. [Best Practices](#best-practices) +8. [Troubleshooting](#troubleshooting) + +## Introduction + +LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that allows you to train specialized adapters for different domains without modifying the base model. This enables: + +- **Fast domain switching**: Swap between domains in milliseconds +- **Minimal storage**: Adapters are ~2-10 MB vs ~100-500 MB for full models +- **Domain specialization**: Train separate adapters for legal, medical, financial, etc. +- **Easy deployment**: Keep one base model + multiple lightweight adapters + +## Why Use LoRA Adapters? 
+ +### Memory Efficiency + +``` +Full Model Fine-tuning: +- Legal model: 450 MB +- Medical model: 450 MB +- Financial model: 450 MB +Total: 1.35 GB + +LoRA Adapters: +- Base model: 450 MB +- Legal adapter: 5 MB +- Medical adapter: 5 MB +- Financial adapter: 5 MB +Total: 465 MB (65% less!) +``` + +### Fast Training + +LoRA adapters train **2-3x faster** than full fine-tuning because: +- Only ~1-5% of parameters are trainable +- Smaller gradient computations +- Less GPU memory required + +### Easy Multi-Domain Inference + +```python +# One base model, multiple domains +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + +# Legal domain +model.load_adapter("./legal_adapter") +legal_results = model.extract_entities(legal_text, ["company", "law"]) + +# Medical domain (swap in <1 second) +model.load_adapter("./medical_adapter") +medical_results = model.extract_entities(medical_text, ["disease", "drug"]) +``` + +## Training Your First Adapter + +### Step 1: Prepare Domain-Specific Data + +```python +from gliner2.training.data import InputExample + +# Legal domain examples +legal_examples = [ + InputExample( + text="Apple Inc. 
filed a lawsuit against Samsung Electronics.",
+        entities={"company": ["Apple Inc.", "Samsung Electronics"]}
+    ),
+    InputExample(
+        text="The plaintiff Google LLC accused Microsoft Corporation of patent infringement.",
+        entities={"company": ["Google LLC", "Microsoft Corporation"]}
+    ),
+    InputExample(
+        text="Tesla Motors settled the case with the Securities and Exchange Commission.",
+        entities={
+            "company": ["Tesla Motors"],
+            "organization": ["Securities and Exchange Commission"]
+        }
+    ),
+    # Add 100-1000+ examples for best results
+]
+```
+
+### Step 2: Configure LoRA Training
+
+```python
+from gliner2 import GLiNER2
+from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig
+
+# LoRA configuration
+config = TrainingConfig(
+    output_dir="./legal_adapter",
+    experiment_name="legal_domain",
+    
+    # Training parameters
+    num_epochs=10,
+    batch_size=8,
+    gradient_accumulation_steps=2,
+    encoder_lr=1e-5,
+    task_lr=5e-4,
+    
+    # LoRA settings
+    use_lora=True,  # Enable LoRA
+    lora_r=8,  # Rank (4, 8, 16, 32)
+    lora_alpha=16.0,  # Scaling factor (usually 2*r)
+    lora_dropout=0.0,  # Dropout for LoRA layers
+    lora_target_modules=["encoder"],  # Apply to all encoder layers (query, key, value, dense)
+    save_adapter_only=True,  # Save only adapter (not full model)
+    
+    # Optimization
+    eval_strategy="epoch",  # Evaluates and saves at end of each epoch
+    eval_steps=500,  # Used when eval_strategy="steps"
+    logging_steps=50,
+    fp16=True,  # Use mixed precision if GPU available
+)
+```
+
+### Step 3: Train the Adapter
+
+```python
+# Load base model
+base_model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
+
+# Create trainer
+trainer = GLiNER2Trainer(model=base_model, config=config)
+
+# Train adapter
+trainer.train(train_data=legal_examples)
+
+# Adapter automatically saved to ./legal_adapter/final/
+```
+
+**Training output:**
+```
+🔧 LoRA Configuration
+======================================================================
+Enabled : True
+Rank (r) : 8
+Alpha : 
16.0
+Scaling (α/r) : 2.0000
+Dropout : 0.0
+Target modules : query, key, value, dense
+LoRA layers : 144
+----------------------------------------------------------------------
+Trainable params : 1,327,104 / 124,442,368 (1.07%)
+Memory savings : ~98.9% fewer gradients
+======================================================================
+
+***** Running Training *****
+  Num examples = 1000
+  Num epochs = 10
+  Batch size = 8
+  Effective batch size = 16
+  Total optimization steps = 625
+  LoRA enabled: 1,327,104 trainable / 124,442,368 total (1.07%)
+```
+
+## Training Multiple Domain Adapters
+
+Let's train adapters for three different domains: **Legal**, **Medical**, and **Customer Support**.
+
+### Complete Multi-Domain Training Script
+
+```python
+from gliner2 import GLiNER2
+from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig
+from gliner2.training.data import InputExample
+
+# ============================================================================
+# Define Domain Data
+# ============================================================================
+
+# Legal domain
+legal_examples = [
+    InputExample(
+        text="Apple Inc. filed a lawsuit against Samsung Electronics.",
+        entities={"company": ["Apple Inc.", "Samsung Electronics"]}
+    ),
+    InputExample(
+        text="The plaintiff Google LLC accused Microsoft Corporation of patent infringement.",
+        entities={"company": ["Google LLC", "Microsoft Corporation"]}
+    ),
+    # Add more examples...
+]
+
+# Medical domain
+medical_examples = [
+    InputExample(
+        text="Patient diagnosed with Type 2 Diabetes and Hypertension.",
+        entities={"disease": ["Type 2 Diabetes", "Hypertension"]}
+    ),
+    InputExample(
+        text="Prescribed Metformin 500mg twice daily and Lisinopril 10mg once daily.",
+        entities={
+            "drug": ["Metformin", "Lisinopril"],
+            "dosage": ["500mg", "10mg"]
+        }
+    ),
+    # Add more examples...
+]
+
+# Customer support domain
+support_examples = [
+    InputExample(
+        text="Customer John Smith reported issue with Order #12345.",
+        entities={
+            "customer": ["John Smith"],
+            "order_id": ["Order #12345"]
+        }
+    ),
+    InputExample(
+        text="Refund of $99.99 processed for Order #98765 on 2024-01-15.",
+        entities={
+            "order_id": ["Order #98765"],
+            "amount": ["$99.99"],
+            "date": ["2024-01-15"]
+        }
+    ),
+    # Add more examples...
+]
+
+# ============================================================================
+# Training Function
+# ============================================================================
+
+def train_domain_adapter(
+    base_model_name: str,
+    examples: list,
+    domain_name: str,
+    output_dir: str = "./adapters"
+):
+    """Train a LoRA adapter for a specific domain."""
+    
+    adapter_path = f"{output_dir}/{domain_name}_adapter"
+    
+    config = TrainingConfig(
+        output_dir=adapter_path,
+        experiment_name=f"{domain_name}_domain",
+        
+        # Training
+        num_epochs=10,
+        batch_size=8,
+        gradient_accumulation_steps=2,
+        encoder_lr=1e-5,
+        task_lr=5e-4,
+        
+        # LoRA
+        use_lora=True,
+        lora_r=8,
+        lora_alpha=16.0,
+        lora_dropout=0.0,
+        lora_target_modules=["encoder"],  # All encoder layers
+        save_adapter_only=True,
+        
+        # Logging & Checkpointing
+        eval_strategy="no",  # Set to "epoch" or "steps" if you have validation set
+        eval_steps=500,  # Used when eval_strategy="steps"
+        logging_steps=50,
+        fp16=True,
+    )
+    
+    # Load base model
+    print(f"\n{'='*60}")
+    print(f"Training {domain_name.upper()} adapter")
+    print(f"{'='*60}")
+    
+    model = GLiNER2.from_pretrained(base_model_name)
+    trainer = GLiNER2Trainer(model=model, config=config)
+    
+    # Train
+    results = trainer.train(train_data=examples)
+    
+    print(f"\n✅ {domain_name.capitalize()} adapter trained!")
+    print(f"📁 Saved to: {adapter_path}/final/")
+    print(f"⏱️ Training time: {results['total_time_seconds']:.2f}s")
+    
+    return f"{adapter_path}/final"
+
+# 
============================================================================
+# Train All Adapters
+# ============================================================================
+
+if __name__ == "__main__":
+    BASE_MODEL = "fastino/gliner2-base-v1"
+    
+    # Train adapters for each domain
+    legal_adapter_path = train_domain_adapter(
+        BASE_MODEL, legal_examples, "legal"
+    )
+    
+    medical_adapter_path = train_domain_adapter(
+        BASE_MODEL, medical_examples, "medical"
+    )
+    
+    support_adapter_path = train_domain_adapter(
+        BASE_MODEL, support_examples, "support"
+    )
+    
+    print("\n" + "="*60)
+    print("🎉 All adapters trained successfully!")
+    print("="*60)
+    print(f"Legal adapter: {legal_adapter_path}")
+    print(f"Medical adapter: {medical_adapter_path}")
+    print(f"Support adapter: {support_adapter_path}")
+```
+
+## Loading and Swapping Adapters
+
+### Basic Usage
+
+```python
+from gliner2 import GLiNER2
+
+# Load base model once
+model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
+
+# Load legal adapter
+model.load_adapter("./adapters/legal_adapter/final")
+
+# Use the model
+result = model.extract_entities(
+    "Apple Inc. sued Samsung over patent rights.",
+    ["company", "legal_action"]
+)
+print(result)
+```
+
+### Swapping Between Adapters
+
+```python
+# Load base model
+model = GLiNER2.from_pretrained("fastino/gliner2-base-v1")
+
+# Legal domain
+print("📋 Legal Analysis:")
+model.load_adapter("./adapters/legal_adapter/final")
+legal_text = "Google LLC filed a complaint against Oracle Corporation."
+legal_result = model.extract_entities(legal_text, ["company", "legal_action"])
+print(f"  {legal_result}")
+
+# Swap to medical domain
+print("\n🏥 Medical Analysis:")
+model.load_adapter("./adapters/medical_adapter/final")
+medical_text = "Patient presents with Pneumonia and was prescribed Amoxicillin."
+medical_result = model.extract_entities(medical_text, ["disease", "drug"])
+print(f"  {medical_result}")
+
+# Swap to support domain
+print("\n💬 Support Analysis:")
+model.load_adapter("./adapters/support_adapter/final")
+support_text = "Customer reported Order #12345 not delivered on time."
+support_result = model.extract_entities(support_text, ["order_id", "issue"])
+print(f"  {support_result}")
+
+# Use base model without adapter
+print("\n🔧 Base Model (no adapter):")
+model.unload_adapter()
+base_result = model.extract_entities("Some generic text", ["entity"])
+print(f"  {base_result}")
+```
+
+**Output:**
+```
+📋 Legal Analysis:
+  {'entities': [{'text': 'Google LLC', 'label': 'company', ...},
+                {'text': 'Oracle Corporation', 'label': 'company', ...}]}
+
+🏥 Medical Analysis:
+  {'entities': [{'text': 'Pneumonia', 'label': 'disease', ...},
+                {'text': 'Amoxicillin', 'label': 'drug', ...}]}
+
+💬 Support Analysis:
+  {'entities': [{'text': 'Order #12345', 'label': 'order_id', ...}]}
+
+🔧 Base Model (no adapter):
+  {'entities': [{'text': 'text', 'label': 'entity', ...}]}
+```
+
+### Batch Processing with Adapter Swapping
+
+```python
+def process_documents_by_domain(model, documents_by_domain, adapters):
+    """
+    Process multiple documents across different domains efficiently.
+ + Args: + model: Base GLiNER2 model + documents_by_domain: Dict[domain_name, List[document_text]] + adapters: Dict[domain_name, adapter_path] + + Returns: + Dict[domain_name, List[results]] + """ + results = {} + + for domain, documents in documents_by_domain.items(): + print(f"Processing {domain} domain ({len(documents)} documents)...") + + # Load domain-specific adapter + model.load_adapter(adapters[domain]) + + # Process all documents for this domain + domain_results = [] + for doc in documents: + result = model.extract_entities(doc, get_entity_types(domain)) + domain_results.append(result) + + results[domain] = domain_results + + return results + +def get_entity_types(domain): + """Get entity types for each domain.""" + types = { + "legal": ["company", "person", "law", "legal_action"], + "medical": ["disease", "drug", "symptom", "procedure"], + "support": ["customer", "order_id", "product", "issue"] + } + return types.get(domain, ["entity"]) + +# Example usage +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + +documents_by_domain = { + "legal": [ + "Apple Inc. 
filed suit against Samsung.", + "Microsoft acquired LinkedIn for $26B.", + ], + "medical": [ + "Patient has Type 2 Diabetes.", + "Prescribed Metformin 500mg daily.", + ], + "support": [ + "Issue with Order #12345 reported.", + "Refund processed for Order #98765.", + ] +} + +adapters = { + "legal": "./adapters/legal_adapter/final", + "medical": "./adapters/medical_adapter/final", + "support": "./adapters/support_adapter/final", +} + +results = process_documents_by_domain(model, documents_by_domain, adapters) + +# Results organized by domain +for domain, domain_results in results.items(): + print(f"\n{domain.upper()} Results:") + for i, result in enumerate(domain_results, 1): + print(f" Document {i}: {len(result['entities'])} entities found") +``` + +## Real-World Use Cases + +### Use Case 1: Multi-Tenant SaaS Platform + +```python +class MultiTenantEntityExtractor: + """Entity extraction service for multi-tenant platform.""" + + def __init__(self, base_model_name: str, tenant_adapters: dict): + """ + Args: + base_model_name: Path to base model + tenant_adapters: Dict mapping tenant_id to adapter_path + """ + self.model = GLiNER2.from_pretrained(base_model_name) + self.tenant_adapters = tenant_adapters + self.current_tenant = None + + def extract_for_tenant(self, tenant_id: str, text: str, entity_types: list): + """Extract entities for specific tenant.""" + # Load tenant-specific adapter if needed + if self.current_tenant != tenant_id: + adapter_path = self.tenant_adapters.get(tenant_id) + if adapter_path: + self.model.load_adapter(adapter_path) + else: + self.model.unload_adapter() # Use base model + self.current_tenant = tenant_id + + return self.model.extract_entities(text, entity_types) + +# Setup +extractor = MultiTenantEntityExtractor( + base_model_name="fastino/gliner2-base-v1", + tenant_adapters={ + "legal_firm_123": "./adapters/legal_adapter/final", + "hospital_456": "./adapters/medical_adapter/final", + "ecommerce_789": "./adapters/support_adapter/final", + 
} +) + +# Usage +legal_result = extractor.extract_for_tenant( + "legal_firm_123", + "Apple sued Samsung", + ["company"] +) + +medical_result = extractor.extract_for_tenant( + "hospital_456", + "Patient has diabetes", + ["disease"] +) +``` + +### Use Case 2: Document Classification Pipeline + +```python +def classify_and_extract(document: str, model: GLiNER2, adapters: dict): + """ + Classify document type and extract relevant entities. + + 1. Classify document type using base model + 2. Load appropriate domain adapter + 3. Extract domain-specific entities + """ + # Step 1: Classify document type + doc_type_result = model.extract_entities( + document, + ["legal_document", "medical_record", "support_ticket", "financial_report"] + ) + + # Determine document type + if doc_type_result['entities']: + doc_type = doc_type_result['entities'][0]['label'] + doc_type = doc_type.replace("_document", "").replace("_record", "").replace("_ticket", "").replace("_report", "") + else: + doc_type = "general" + + # Step 2: Load appropriate adapter + adapter_mapping = { + "legal": adapters.get("legal"), + "medical": adapters.get("medical"), + "support": adapters.get("support"), + "financial": adapters.get("financial"), + } + + if doc_type in adapter_mapping and adapter_mapping[doc_type]: + model.load_adapter(adapter_mapping[doc_type]) + + # Step 3: Extract domain-specific entities + entity_types = { + "legal": ["company", "person", "law", "legal_action"], + "medical": ["disease", "drug", "symptom", "procedure", "dosage"], + "support": ["customer", "order_id", "product", "issue", "status"], + "financial": ["company", "amount", "date", "stock_symbol"], + } + + entities = model.extract_entities( + document, + entity_types.get(doc_type, ["entity"]) + ) + + return { + "document_type": doc_type, + "entities": entities['entities'] + } + +# Usage +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + +adapters = { + "legal": "./adapters/legal_adapter/final", + "medical": 
"./adapters/medical_adapter/final", + "support": "./adapters/support_adapter/final", +} + +document = "Patient John Smith diagnosed with Type 2 Diabetes on 2024-01-15." +result = classify_and_extract(document, model, adapters) + +print(f"Document Type: {result['document_type']}") +print(f"Entities: {result['entities']}") +``` + +### Use Case 3: A/B Testing Adapters + +```python +import random + +class AdapterABTester: + """A/B test different adapter versions.""" + + def __init__(self, base_model_name: str, adapter_variants: dict): + """ + Args: + adapter_variants: {"v1": path1, "v2": path2, ...} + """ + self.model = GLiNER2.from_pretrained(base_model_name) + self.adapter_variants = adapter_variants + self.results = {variant: [] for variant in adapter_variants} + + def test_sample(self, text: str, entity_types: list, true_entities: list): + """Test a sample with all adapter variants.""" + sample_results = {} + + for variant, adapter_path in self.adapter_variants.items(): + # Load variant + self.model.load_adapter(adapter_path) + + # Get predictions + pred = self.model.extract_entities(text, entity_types) + + # Compute metrics + f1 = self.compute_f1(pred['entities'], true_entities) + + sample_results[variant] = { + "predictions": pred['entities'], + "f1_score": f1 + } + + self.results[variant].append(f1) + + return sample_results + + def compute_f1(self, predicted, ground_truth): + """Simple F1 computation (simplified for demo).""" + pred_set = {(e['text'], e['label']) for e in predicted} + true_set = {(e['text'], e['label']) for e in ground_truth} + + if not pred_set and not true_set: + return 1.0 + if not pred_set or not true_set: + return 0.0 + + tp = len(pred_set & true_set) + precision = tp / len(pred_set) if pred_set else 0 + recall = tp / len(true_set) if true_set else 0 + + if precision + recall == 0: + return 0.0 + return 2 * precision * recall / (precision + recall) + + def get_summary(self): + """Get A/B test summary.""" + summary = {} + for variant, 
scores in self.results.items(): + if scores: + summary[variant] = { + "avg_f1": sum(scores) / len(scores), + "samples": len(scores) + } + return summary + +# Usage +tester = AdapterABTester( + base_model_name="fastino/gliner2-base-v1", + adapter_variants={ + "v1_r4": "./adapters/legal_v1_r4/final", + "v2_r8": "./adapters/legal_v2_r8/final", + "v3_r16": "./adapters/legal_v3_r16/final", + } +) + +# Test samples +test_samples = [ + { + "text": "Apple Inc. sued Samsung Electronics.", + "entity_types": ["company"], + "true_entities": [ + {"text": "Apple Inc.", "label": "company"}, + {"text": "Samsung Electronics", "label": "company"} + ] + }, + # More samples... +] + +for sample in test_samples: + results = tester.test_sample( + sample["text"], + sample["entity_types"], + sample["true_entities"] + ) + +# Get summary +summary = tester.get_summary() +for variant, metrics in summary.items(): + print(f"{variant}: Avg F1 = {metrics['avg_f1']:.3f} ({metrics['samples']} samples)") +``` + +## Best Practices + +### 1. Choosing LoRA Hyperparameters + +```python +# Small datasets (< 1K examples) +config = TrainingConfig( + lora_r=4, # Lower rank = fewer parameters + lora_alpha=8.0, # alpha = 2 * r + num_epochs=10, +) + +# Medium datasets (1K-10K examples) +config = TrainingConfig( + lora_r=8, # Standard rank + lora_alpha=16.0, + num_epochs=5, +) + +# Large datasets (> 10K examples) +config = TrainingConfig( + lora_r=16, # Higher rank = more capacity + lora_alpha=32.0, + num_epochs=3, +) +``` + +### 2. 
Target Module Selection + +**Understanding Module Groups:** + +GLiNER2 supports fine-grained control over which layers receive LoRA adaptation: + +```python +# Option 1: Encoder only - all layers (query, key, value, dense) +# Use case: General domain adaptation, good starting point +# Memory: Moderate (~1-2% of model parameters) +lora_target_modules=["encoder"] + +# Option 2: Encoder - attention layers only +# Use case: Very memory-constrained scenarios +# Memory: Low (~0.5-1% of model parameters) +lora_target_modules=["encoder.query", "encoder.key", "encoder.value"] + +# Option 3: Encoder - FFN layers only +# Use case: Alternative to attention-only, sometimes better for certain tasks +# Memory: Low (~0.5-1% of model parameters) +lora_target_modules=["encoder.dense"] + +# Option 4: Encoder + task heads +# Use case: When you want to adapt both representation and task-specific layers +# Memory: Moderate-High (~2-4% of model parameters) +lora_target_modules=["encoder", "span_rep", "classifier"] + +# Option 5: All modules (DEFAULT) +# Use case: Maximum adaptation capacity, best performance +# Memory: High (~3-5% of model parameters) +lora_target_modules=["encoder", "span_rep", "classifier", "count_embed", "count_pred"] +``` + +**Recommendations:** +- **Start with encoder only** (`["encoder"]`) for most tasks +- **Add task heads** if performance is insufficient +- **Use attention-only** for extreme memory constraints +- **Use all modules** (default) when you need maximum performance + +### 3. Adapter Organization + +``` +project/ +β”œβ”€β”€ base_model/ +β”‚ └── gliner2-base-v1/ +β”œβ”€β”€ adapters/ +β”‚ β”œβ”€β”€ legal/ +β”‚ β”‚ β”œβ”€β”€ v1_r8/ +β”‚ β”‚ β”‚ └── final/ +β”‚ β”‚ └── v2_r16/ +β”‚ β”‚ └── final/ +β”‚ β”œβ”€β”€ medical/ +β”‚ β”‚ └── final/ +β”‚ └── support/ +β”‚ └── final/ +└── scripts/ + β”œβ”€β”€ train_adapters.py + └── evaluate_adapters.py +``` + +### 4. 
Version Control for Adapters + +```python +# adapter_metadata.json +{ + "legal_v1": { + "path": "./adapters/legal/v1_r8/final", + "base_model": "fastino/gliner2-base-v1", + "lora_r": 8, + "lora_alpha": 16.0, + "trained_on": "2024-01-15", + "training_samples": 5000, + "eval_f1": 0.87, + "notes": "Initial legal domain adapter" + }, + "legal_v2": { + "path": "./adapters/legal/v2_r16/final", + "base_model": "fastino/gliner2-base-v1", + "lora_r": 16, + "lora_alpha": 32.0, + "trained_on": "2024-02-01", + "training_samples": 10000, + "eval_f1": 0.92, + "notes": "Improved with more data and higher rank" + } +} +``` + +### 5. Monitoring Adapter Performance + +```python +def evaluate_adapter(model, adapter_path, test_data): + """Evaluate adapter performance on test data.""" + model.load_adapter(adapter_path) + + results = { + "total": 0, + "correct": 0, + "precision_sum": 0, + "recall_sum": 0, + } + + for sample in test_data: + pred = model.extract_entities(sample["text"], sample["entity_types"]) + + # Compute metrics + metrics = compute_metrics(pred['entities'], sample["true_entities"]) + results["total"] += 1 + results["precision_sum"] += metrics["precision"] + results["recall_sum"] += metrics["recall"] + + avg_precision = results["precision_sum"] / results["total"] + avg_recall = results["recall_sum"] / results["total"] + f1 = 2 * avg_precision * avg_recall / (avg_precision + avg_recall) + + return { + "precision": avg_precision, + "recall": avg_recall, + "f1": f1, + "samples": results["total"] + } +``` + +## Troubleshooting + +### Issue 1: Adapter Not Affecting Predictions + +**Symptom**: Predictions are the same with and without adapter. 
+ +**Solution**: +```python +# Check if adapter is actually loaded +print(f"Has adapter: {model.has_adapter}") + +# Check LoRA layers +from gliner2.training.lora import LoRALayer +lora_count = sum(1 for m in model.modules() if isinstance(m, LoRALayer)) +print(f"LoRA layers: {lora_count}") + +# Should be > 0 if adapter is loaded +assert lora_count > 0, "No LoRA layers found!" +``` + +### Issue 2: Out of Memory During Training + +**Solution**: +```python +config = TrainingConfig( + # Reduce batch size + batch_size=4, # Instead of 8 + gradient_accumulation_steps=4, # Maintain effective batch size + + # Use smaller LoRA rank + lora_r=4, # Instead of 8 + + # Enable mixed precision + fp16=True, + + # Target only attention layers (fewer parameters) + lora_target_modules=["encoder.query", "encoder.key", "encoder.value"], +) +``` + +### Issue 3: Adapter File Not Found + +**Solution**: +```python +import os +from gliner2.training.lora import LoRAAdapterConfig + +adapter_path = "./adapters/legal_adapter/final" + +# Check if path exists +if not os.path.exists(adapter_path): + print(f"Path does not exist: {adapter_path}") + # List available checkpoints + checkpoint_dir = "./adapters/legal_adapter" + if os.path.exists(checkpoint_dir): + checkpoints = os.listdir(checkpoint_dir) + print(f"Available checkpoints: {checkpoints}") + +# Check if it's a valid adapter +if LoRAAdapterConfig.is_adapter_path(adapter_path): + print("Valid adapter path!") + config = LoRAAdapterConfig.load(adapter_path) + print(f"Adapter config: {config}") +else: + print("Not a valid adapter path!") +``` + +### Issue 4: Slow Adapter Switching + +**Problem**: Switching between adapters takes too long. 
+ +**Solution**: +```python +# Pre-load adapters in memory (if you have enough RAM) +adapters = {} +for domain, path in adapter_paths.items(): + # Load adapter weights into memory + adapters[domain] = load_adapter_to_memory(path) + +# Fast switching from memory (not implemented in base API, +# but possible with custom caching layer) +``` + +## Summary + +### Key Takeaways + +βœ… **LoRA adapters** enable efficient multi-domain inference +βœ… **Training** is 2-3x faster than full fine-tuning +βœ… **Storage** savings of 65-95% compared to multiple full models +βœ… **Swapping** adapters takes < 1 second +βœ… **Domain specialization** improves accuracy on specific tasks + +### Quick Reference + +```python +# Training +config = TrainingConfig( + use_lora=True, + lora_r=8, + lora_alpha=16.0, + save_adapter_only=True, +) +trainer.train(train_data=examples) + +# Loading +model = GLiNER2.from_pretrained("base-model") +model.load_adapter("./adapter/final") + +# Swapping +model.load_adapter("./other_adapter/final") + +# Unloading +model.unload_adapter() + +# Checking +print(model.has_adapter) +print(model.adapter_config) +``` + +### Next Steps + +1. **Train your first adapter** with domain-specific data +2. **Evaluate performance** on test set +3. **Experiment with hyperparameters** (rank, alpha, target modules) +4. **Deploy multiple adapters** for different use cases +5. 
**Monitor and iterate** based on real-world performance + +For more information: +- LoRA Paper: https://arxiv.org/abs/2106.09685 +- Implementation: `gliner2/training/lora.py` +- Tests: `tests/test_lora_adapters.py` +- Verification Guide: `LORA_VERIFICATION_TESTS.md` + diff --git a/packages/GLiNER2/tutorial/11-adapter_switching.md b/packages/GLiNER2/tutorial/11-adapter_switching.md new file mode 100644 index 0000000..8a8bf99 --- /dev/null +++ b/packages/GLiNER2/tutorial/11-adapter_switching.md @@ -0,0 +1,201 @@ +# Tutorial 11: LoRA Adapter Switching/Routing + +## Quick Start + +Switch between domain-specific adapters during inference without reloading the base model. + +```python +from gliner2 import GLiNER2 + +# Load base model once +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + +# Load legal adapter +model.load_adapter("./legal_adapter") +legal_result = model.extract_entities("Apple sued Google", ["company"]) + +# Switch to medical adapter +model.load_adapter("./medical_adapter") +medical_result = model.extract_entities("Patient has diabetes", ["disease"]) + +# Use base model (no adapter) +model.unload_adapter() +base_result = model.extract_entities("Some text", ["entity"]) +``` + +## Basic Usage + +### Loading an Adapter + +```python +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +model.load_adapter("./path/to/adapter") +``` + +The adapter path should point to a directory containing: +- `adapter_config.json` +- `adapter_weights.safetensors` + +### Checking Adapter Status + +```python +# Check if adapter is loaded +if model.has_adapter: + print("Adapter is loaded") + +# Get adapter configuration +config = model.adapter_config +print(f"LoRA rank: {config.lora_r}") +``` + +### Unloading an Adapter + +```python +# Remove adapter, use base model +model.unload_adapter() +``` + +## Switching Between Adapters + +Adapters automatically swap when you call `load_adapter()`: + +```python +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") 
+ +# Legal domain +model.load_adapter("./legal_adapter") +result1 = model.extract_entities("Apple Inc. filed suit", ["company"]) + +# Medical domain (previous adapter auto-unloaded) +model.load_adapter("./medical_adapter") +result2 = model.extract_entities("Patient has diabetes", ["disease"]) + +# Support domain +model.load_adapter("./support_adapter") +result3 = model.extract_entities("Order #12345 issue", ["order_id"]) +``` + +## Routing by Document Type + +Route documents to appropriate adapters: + +```python +def extract_with_routing(model, text, doc_type, adapters): + """Route document to domain-specific adapter.""" + adapter_path = adapters.get(doc_type) + + if adapter_path: + model.load_adapter(adapter_path) + else: + model.unload_adapter() # Use base model + + # Define entity types per domain + entity_types = { + "legal": ["company", "person", "law"], + "medical": ["disease", "drug", "symptom"], + "support": ["order_id", "customer", "issue"] + } + + return model.extract_entities( + text, + entity_types.get(doc_type, ["entity"]) + ) + +# Setup +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +adapters = { + "legal": "./legal_adapter", + "medical": "./medical_adapter", + "support": "./support_adapter" +} + +# Use +result = extract_with_routing( + model, + "Apple sued Google", + "legal", + adapters +) +``` + +## Batch Processing by Domain + +Process multiple documents efficiently: + +```python +def process_by_domain(model, documents, adapters): + """Process documents grouped by domain.""" + results = {} + + for domain, docs in documents.items(): + # Load domain adapter + model.load_adapter(adapters[domain]) + + # Process all documents for this domain + results[domain] = [ + model.extract_entities(doc, get_entity_types(domain)) + for doc in docs + ] + + return results + +# Example +documents = { + "legal": ["Apple sued Samsung", "Microsoft acquired LinkedIn"], + "medical": ["Patient has diabetes", "Prescribed Metformin"] +} + +adapters = { + "legal": 
"./legal_adapter", + "medical": "./medical_adapter" +} + +results = process_by_domain(model, documents, adapters) +``` + +## Simple Router Class + +```python +class AdapterRouter: + """Simple adapter router for multi-domain inference.""" + + def __init__(self, base_model_name, adapters): + self.model = GLiNER2.from_pretrained(base_model_name) + self.adapters = adapters + self.current_domain = None + + def extract(self, text, domain, entity_types): + """Extract entities using domain-specific adapter.""" + # Load adapter if domain changed + if self.current_domain != domain: + adapter_path = self.adapters.get(domain) + if adapter_path: + self.model.load_adapter(adapter_path) + else: + self.model.unload_adapter() + self.current_domain = domain + + return self.model.extract_entities(text, entity_types) + +# Usage +router = AdapterRouter( + "fastino/gliner2-base-v1", + { + "legal": "./legal_adapter", + "medical": "./medical_adapter" + } +) + +result = router.extract("Apple sued Google", "legal", ["company"]) +``` + +## Summary + +- **Load adapter**: `model.load_adapter(path)` +- **Unload adapter**: `model.unload_adapter()` +- **Check status**: `model.has_adapter` +- **Get config**: `model.adapter_config` +- **Auto-swap**: Loading a new adapter automatically unloads the previous one + +For training adapters, see [Tutorial 10: LoRA Adapters](10-lora_adapters.md). + diff --git a/packages/GLiNER2/tutorial/2-ner.md b/packages/GLiNER2/tutorial/2-ner.md new file mode 100644 index 0000000..5ea389f --- /dev/null +++ b/packages/GLiNER2/tutorial/2-ner.md @@ -0,0 +1,372 @@ +# GLiNER2 Entity Extraction Tutorial + +Learn how to extract named entities from text using GLiNER2's flexible entity recognition capabilities. 
+ +## Table of Contents +- [Basic Entity Extraction](#basic-entity-extraction) +- [Entity Extraction with Descriptions](#entity-extraction-with-descriptions) +- [Single vs Multiple Entities](#single-vs-multiple-entities) +- [Custom Thresholds](#custom-thresholds) +- [Advanced Configuration](#advanced-configuration) +- [Domain-Specific Entities](#domain-specific-entities) +- [Best Practices](#best-practices) + +## Basic Entity Extraction + +### Simple Example + +```python +from gliner2 import GLiNER2 + +# Load model +extractor = GLiNER2.from_pretrained("your-model-name") + +# Extract common entities +text = "Apple Inc. CEO Tim Cook announced the new iPhone 15 in Cupertino, California on September 12, 2023." +results = extractor.extract_entities( + text, + ["company", "person", "product", "location", "date"] +) +print(results) +# Output: { +# 'entities': { +# 'company': ['Apple Inc.'], +# 'person': ['Tim Cook'], +# 'product': ['iPhone 15'], +# 'location': ['Cupertino', 'California'], +# 'date': ['September 12, 2023'] +# } +# } +``` + +### Using Schema Builder + +```python +# Same extraction using schema +schema = extractor.create_schema().entities([ + "company", "person", "product", "location", "date" +]) +results = extractor.extract(text, schema) +``` + +## Entity Extraction with Descriptions + +Descriptions significantly improve extraction accuracy by providing context. + +```python +# Medical entity extraction +schema = extractor.create_schema().entities({ + "drug": "Pharmaceutical drugs, medications, or treatment names", + "disease": "Medical conditions, illnesses, or disorders", + "symptom": "Clinical symptoms or patient-reported symptoms", + "dosage": "Medication amounts like '50mg' or '2 tablets daily'", + "organ": "Body parts or organs mentioned in medical context" +}) + +medical_text = """ +Patient was prescribed Metformin 500mg twice daily for Type 2 Diabetes. +She reported fatigue and occasional dizziness. Liver function tests ordered. 
+""" + +results = extractor.extract(medical_text, schema) +print(results) +# Output: { +# 'entities': { +# 'drug': ['Metformin'], +# 'disease': ['Type 2 Diabetes'], +# 'symptom': ['fatigue', 'dizziness'], +# 'dosage': ['500mg twice daily'], +# 'organ': ['Liver'] +# } +# } +``` + +## Single vs Multiple Entities + +Control whether to extract one or multiple entities per type. + +### Multiple Entities (Default) + +```python +# Default behavior - extracts all matching entities +schema = extractor.create_schema().entities( + ["person", "organization"], + dtype="list" # Default +) + +text = "Bill Gates and Steve Jobs founded Microsoft and Apple respectively." +results = extractor.extract(text, schema) +# Output: { +# 'entities': { +# 'person': ['Bill Gates', 'Steve Jobs'], +# 'organization': ['Microsoft', 'Apple'] +# } +# } +``` + +### Single Entity per Type + +```python +# Extract only the best match per entity type +schema = extractor.create_schema().entities( + ["company", "ceo"], + dtype="str" # Single entity mode +) + +text = "Apple CEO Tim Cook met with Microsoft CEO Satya Nadella." +results = extractor.extract(text, schema) +# Output: { +# 'entities': { +# 'company': 'Apple', # Just one, despite multiple in text +# 'ceo': 'Tim Cook' # Just one +# } +# } +``` + +## Custom Thresholds + +Set confidence thresholds for precise control. + +### Global Threshold + +```python +# High-precision extraction +results = extractor.extract_entities( + text, + ["email", "phone", "address"], + threshold=0.8 # High confidence required +) +``` + +### With Confidence Scores and Character Positions + +You can include confidence scores and character-level start/end positions using `include_confidence` and `include_spans` parameters: + +```python +# Extract entities with confidence scores +text = "Apple Inc. CEO Tim Cook announced iPhone 15 in Cupertino." 
+results = extractor.extract_entities(
+    text,
+    ["company", "person", "product"],
+    include_confidence=True
+)
+print(results)
+# Output: {
+#     'entities': {
+#         'company': [
+#             {'text': 'Apple Inc.', 'confidence': 0.95}
+#         ],
+#         'person': [
+#             {'text': 'Tim Cook', 'confidence': 0.92}
+#         ],
+#         'product': [
+#             {'text': 'iPhone 15', 'confidence': 0.88}
+#         ]
+#     }
+# }
+
+# Extract with character positions (spans)
+results = extractor.extract_entities(
+    text,
+    ["company", "person"],
+    include_spans=True
+)
+print(results)
+# Output: {
+#     'entities': {
+#         'company': [
+#             {'text': 'Apple Inc.', 'start': 0, 'end': 10}
+#         ],
+#         'person': [
+#             {'text': 'Tim Cook', 'start': 15, 'end': 23}
+#         ]
+#     }
+# }
+
+# Extract with both confidence and spans
+results = extractor.extract_entities(
+    text,
+    ["company", "product"],
+    include_confidence=True,
+    include_spans=True
+)
+print(results)
+# Output: {
+#     'entities': {
+#         'company': [
+#             {'text': 'Apple Inc.', 'confidence': 0.95, 'start': 0, 'end': 10}
+#         ],
+#         'product': [
+#             {'text': 'iPhone 15', 'confidence': 0.88, 'start': 34, 'end': 43}
+#         ]
+#     }
+# }
+```
+
+**Note**: When `include_confidence` or `include_spans` is True, the output format changes:
+- **Default** (both False): Returns simple text strings: `['Apple Inc.', 'Tim Cook']`
+- **include_confidence=True**: Returns dicts with `{'text': '...', 'confidence': 0.95}`
+- **include_spans=True**: Returns dicts with `{'text': '...', 'start': 0, 'end': 10}`
+- **Both True**: Returns dicts with `{'text': '...', 'confidence': 0.95, 'start': 0, 'end': 10}`
+
+### Per-Entity Thresholds
+
+```python
+# Different thresholds for different entities
+schema = extractor.create_schema().entities({
+    "email": {
+        "description": "Email addresses",
+        "dtype": "list",
+        "threshold": 0.9 # Very high precision for emails
+    },
+    "phone": {
+        "description": "Phone numbers including mobile and landline",
+        "dtype": "list",
+        "threshold": 0.7 # Moderate threshold
+    },
+    "name": {
+        "description": "Person names",
+        "dtype": "list",
+        "threshold": 0.5 # Lower
threshold for names + } +}) + +contact_text = "Contact John Doe at john.doe@email.com or call 555-1234." +results = extractor.extract(contact_text, schema, threshold=0.6) # Default threshold +``` + +## Advanced Configuration + +### Mixed Configuration + +```python +# Combine different entity configurations +schema = extractor.create_schema() + +# Add simple entities +schema.entities(["date", "time", "currency"]) + +# Add entities with descriptions +schema.entities({ + "technical_term": "Technical jargon or specialized terminology", + "metric": "Measurements, KPIs, or quantitative values" +}) + +# Add entities with full configuration +schema.entities({ + "competitor": { + "description": "Competing companies or products", + "dtype": "list", + "threshold": 0.7 + }, + "revenue": { + "description": "Revenue figures or financial amounts", + "dtype": "str", # Only extract one + "threshold": 0.8 + } +}) +``` + +### Incremental Entity Addition + +```python +# Build schema incrementally +schema = extractor.create_schema() + +# Add entities in stages +schema.entities(["person", "location"]) # Basic entities +schema.entities({"company": "Company or organization names"}) # With description +schema.entities({ # With full config + "financial_term": { + "description": "Financial instruments, metrics, or terminology", + "threshold": 0.75 + } +}) +``` + +## Domain-Specific Entities + +### Legal Entities + +```python +legal_schema = extractor.create_schema().entities({ + "party": "Parties involved in legal proceedings (plaintiff, defendant, etc.)", + "law_firm": "Law firm or legal practice names", + "court": "Court names or judicial bodies", + "statute": "Legal statutes, laws, or regulations cited", + "case": "Legal case names or citations", + "judge": "Names of judges or magistrates", + "legal_term": "Legal terminology or concepts" +}) + +legal_text = """ +In the case of Smith v. 
Jones, Judge Sarah Williams of the Superior Court +ruled that the defendant violated Section 15.2 of the Consumer Protection Act. +The plaintiff was represented by Miller & Associates. +""" +results = extractor.extract(legal_text, legal_schema) +``` + +### Financial Entities + +```python +finance_schema = extractor.create_schema().entities({ + "ticker": "Stock ticker symbols (e.g., AAPL, GOOGL)", + "financial_metric": "Financial metrics like P/E ratio, market cap", + "currency_amount": "Monetary values with currency symbols", + "percentage": "Percentage values (e.g., 5.2%, -3%)", + "financial_org": "Banks, investment firms, financial institutions", + "market_index": "Stock market indices (S&P 500, NASDAQ, etc.)" +}) + +finance_text = """ +AAPL rose 3.5% to $185.50 after beating earnings expectations. +The company's P/E ratio of 28.5 attracted Goldman Sachs analysts. +The NASDAQ composite gained 1.2% for the day. +""" +results = extractor.extract(finance_text, finance_schema) +``` + +### Scientific Entities + +```python +science_schema = extractor.create_schema().entities({ + "chemical": "Chemical compounds or elements", + "organism": "Biological organisms, species names", + "gene": "Gene names or identifiers", + "measurement": "Scientific measurements with units", + "research_method": "Research techniques or methodologies", + "institution": "Universities or research institutions" +}) + +science_text = """ +Researchers at MIT discovered that the BRCA1 gene mutation increases +cancer risk by 70%. The study used CRISPR-Cas9 to modify DNA sequences +in Mus musculus specimens, measuring tumor growth in millimeters. +""" +results = extractor.extract(science_text, science_schema) +``` + +## Best Practices + +### 1. Use Descriptive Entity Names + +```python +# Good - Clear, specific entity types +schema.entities(["drug_name", "medical_device", "procedure_name"]) + +# Less ideal - Too generic +schema.entities(["thing", "item", "stuff"]) +``` + +### 2. 
Provide Context with Descriptions + +```python +# Good - Clear descriptions +schema.entities({ + "acquisition_company": "Company that is acquiring another company", + "target_company": "Company being acquired", + "acquisition_price": "Purchase price or valuation of acquisition" +}) + +# Less ideal - No context +schema.entities(["company1", "company2", "price"]) +``` \ No newline at end of file diff --git a/packages/GLiNER2/tutorial/3-json_extraction.md b/packages/GLiNER2/tutorial/3-json_extraction.md new file mode 100644 index 0000000..daf6cb3 --- /dev/null +++ b/packages/GLiNER2/tutorial/3-json_extraction.md @@ -0,0 +1,504 @@ +# GLiNER2 JSON Structure Extraction Tutorial + +Learn how to extract complex structured data from text using GLiNER2's hierarchical extraction capabilities. + +## Table of Contents +- [Quick API with extract_json](#quick-api-with-extract_json) +- [Field Types and Specifications](#field-types-and-specifications) +- [Multiple Instances](#multiple-instances) +- [Schema Builder (Multi-Task)](#schema-builder-multi-task) +- [Real-World Examples](#real-world-examples) +- [Best Practices](#best-practices) + +## Quick API with extract_json + +For structure-only extraction, use the `extract_json()` method with the simple dictionary format: + +### Basic Structure Extraction + +```python +from gliner2 import GLiNER2 + +# Load model +extractor = GLiNER2.from_pretrained("your-model-name") + +# Simple product extraction +text = "The MacBook Pro costs $1999 and features M3 chip, 16GB RAM, and 512GB storage." 
+results = extractor.extract_json( + text, + { + "product": [ + "name::str", + "price", + "features" + ] + } +) +print(results) +# Output: { +# 'product': [{ +# 'name': 'MacBook Pro', +# 'price': ['$1999'], +# 'features': ['M3 chip', '16GB RAM', '512GB storage'] +# }] +# } +``` + +### Contact Information + +```python +text = """ +Contact: John Smith +Email: john@example.com +Phones: 555-1234, 555-5678 +Address: 123 Main St, NYC +""" + +results = extractor.extract_json( + text, + { + "contact": [ + "name::str", + "email::str", + "phone::list", + "address" + ] + } +) +# Output: { +# 'contact': [{ +# 'name': 'John Smith', +# 'email': 'john@example.com', +# 'phone': ['555-1234', '555-5678'], +# 'address': ['123 Main St, NYC'] +# }] +# } +``` + +## Field Types and Specifications + +### Field Specification Format + +Fields support flexible specifications using `::` separators: + +``` +"field_name::type::description" +"field_name::[choice1|choice2|choice3]::type::description" +"field_name::description" # defaults to list type +"field_name" # simple field, defaults to list +``` + +### String vs List Fields + +```python +text = """ +Tech Conference 2024 on June 15th in San Francisco. +Topics include AI, Machine Learning, and Cloud Computing. +Registration fee: $299 for early bird tickets. +""" + +results = extractor.extract_json( + text, + { + "event": [ + "name::str::Event or conference name", + "date::str::Event date", + "location::str", + "topics::list::Conference topics", + "registration_fee::str" + ] + } +) +# Output: { +# 'event': [{ +# 'name': 'Tech Conference 2024', +# 'date': 'June 15th', +# 'location': 'San Francisco', +# 'topics': ['AI', 'Machine Learning', 'Cloud Computing'], +# 'registration_fee': '$299' +# }] +# } +``` + +### Choice Fields (Classification within Structure) + +```python +text = """ +Reservation at Le Bernardin for 4 people on March 15th at 7:30 PM. +We'd prefer outdoor seating. Two guests are vegetarian and one is gluten-free. 
+""" + +results = extractor.extract_json( + text, + { + "reservation": [ + "restaurant::str::Restaurant name", + "date::str", + "time::str", + "party_size::[1|2|3|4|5|6+]::str::Number of guests", + "seating::[indoor|outdoor|bar]::str::Seating preference", + "dietary::[vegetarian|vegan|gluten-free|none]::list::Dietary restrictions" + ] + } +) +# Output: { +# 'reservation': [{ +# 'restaurant': 'Le Bernardin', +# 'date': 'March 15th', +# 'time': '7:30 PM', +# 'party_size': '4', +# 'seating': 'outdoor', +# 'dietary': ['vegetarian', 'gluten-free'] +# }] +# } +``` + +## Multiple Instances + +GLiNER2 automatically extracts ALL instances of a structure found in text: + +### Multiple Transactions + +```python +text = """ +Recent transactions: +- Jan 5: Starbucks $5.50 (food) +- Jan 5: Uber $23.00 (transport) +- Jan 6: Amazon $156.99 (shopping) +""" + +results = extractor.extract_json( + text, + { + "transaction": [ + "date::str", + "merchant::str", + "amount::str", + "category::[food|transport|shopping|utilities]::str" + ] + } +) +# Output: { +# 'transaction': [ +# {'date': 'Jan 5', 'merchant': 'Starbucks', 'amount': '$5.50', 'category': 'food'}, +# {'date': 'Jan 5', 'merchant': 'Uber', 'amount': '$23.00', 'category': 'transport'}, +# {'date': 'Jan 6', 'merchant': 'Amazon', 'amount': '$156.99', 'category': 'shopping'} +# ] +# } +``` + +### Multiple Hotel Bookings + +```python +text = """ +Alice Brown booked the Hilton Downtown from March 10 to March 12. She selected a double room +for $340 total with breakfast and parking included. + +Robert Taylor reserved The Grand Hotel, April 1 to April 5, suite at $1,200 total. +Amenities include breakfast, wifi, gym, and spa access. 
+""" + +results = extractor.extract_json( + text, + { + "booking": [ + "guest::str::Guest name", + "hotel::str::Hotel name", + "check_in::str", + "check_out::str", + "room_type::[single|double|suite|deluxe]::str", + "total_price::str", + "amenities::[breakfast|wifi|parking|gym|spa]::list" + ] + } +) +# Output: { +# 'booking': [ +# { +# 'guest': 'Alice Brown', +# 'hotel': 'Hilton Downtown', +# 'check_in': 'March 10', +# 'check_out': 'March 12', +# 'room_type': 'double', +# 'total_price': '$340', +# 'amenities': ['breakfast', 'parking'] +# }, +# { +# 'guest': 'Robert Taylor', +# 'hotel': 'The Grand Hotel', +# 'check_in': 'April 1', +# 'check_out': 'April 5', +# 'room_type': 'suite', +# 'total_price': '$1,200', +# 'amenities': ['breakfast', 'wifi', 'gym', 'spa'] +# } +# ] +# } +``` + +## Schema Builder (Multi-Task) + +Use `create_schema()` only when combining structured extraction with other tasks (entities, classification): + +### Multi-Task Extraction + +```python +# Use schema builder for multi-task scenarios +schema = (extractor.create_schema() + # Extract entities + .entities(["person", "company", "location"]) + + # Classify sentiment + .classification("sentiment", ["positive", "negative", "neutral"]) + + # Extract structured product info + .structure("product") + .field("name", dtype="str") + .field("price", dtype="str") + .field("features", dtype="list") + .field("category", dtype="str", choices=["electronics", "software", "service"]) +) + +text = "Apple CEO Tim Cook announced iPhone 15 for $999 with amazing new features. This is exciting!" 
+results = extractor.extract(text, schema) +# Output: { +# 'entities': {'person': ['Tim Cook'], 'company': ['Apple'], 'location': []}, +# 'sentiment': 'positive', +# 'product': [{ +# 'name': 'iPhone 15', +# 'price': '$999', +# 'features': ['amazing new features'], +# 'category': 'electronics' +# }] +# } +``` + +### Advanced Configuration + +```python +schema = (extractor.create_schema() + .classification("urgency", ["low", "medium", "high"]) + + .structure("support_ticket") + .field("ticket_id", dtype="str", threshold=0.9) # High precision + .field("customer", dtype="str", description="Customer name") + .field("issue", dtype="str", description="Problem description") + .field("priority", dtype="str", choices=["low", "medium", "high", "urgent"]) + .field("tags", dtype="list", choices=["bug", "feature", "support", "billing"]) +) +``` + +## Examples + +### Financial Transaction Processing + +```python +text = """ +Goldman Sachs processed a $2.5M equity trade for Tesla Inc. on March 15, 2024. +Commission: $1,250. Status: Completed. +""" + +results = extractor.extract_json( + text, + { + "transaction": [ + "broker::str::Financial institution", + "amount::str::Transaction amount", + "security::str::Stock or financial instrument", + "date::str::Transaction date", + "commission::str::Fees charged", + "status::[pending|completed|failed]::str", + "type::[equity|bond|option|future]::str" + ] + } +) +# Output: { +# 'transaction': [{ +# 'broker': 'Goldman Sachs', +# 'amount': '$2.5M', +# 'security': 'Tesla Inc.', +# 'date': 'March 15, 2024', +# 'commission': '$1,250', +# 'status': 'completed', +# 'type': 'equity' +# }] +# } +``` + +### Medical Prescription Extraction + +```python +text = """ +Patient: Sarah Johnson, 34, presented with chest pain. +Prescribed: Lisinopril 10mg daily, Metoprolol 25mg twice daily. +Follow-up scheduled for next Tuesday. 
+""" + +results = extractor.extract_json( + text, + { + "patient": [ + "name::str::Patient full name", + "age::str::Patient age", + "symptoms::list::Reported symptoms" + ], + "prescription": [ + "medication::str::Drug name", + "dosage::str::Dosage amount", + "frequency::str::How often to take" + ] + } +) +# Output: { +# 'patient': [{ +# 'name': 'Sarah Johnson', +# 'age': '34', +# 'symptoms': ['chest pain'] +# }], +# 'prescription': [ +# {'medication': 'Lisinopril', 'dosage': '10mg', 'frequency': 'daily'}, +# {'medication': 'Metoprolol', 'dosage': '25mg', 'frequency': 'twice daily'} +# ] +# } +``` + +### E-commerce Order Processing + +```python +text = """ +Order #ORD-2024-001 for Alexandra Thompson +Items: Laptop Stand (2x $45.99), Wireless Mouse (1x $29.99), USB Hub (3x $35.50) +Subtotal: $228.46, Tax: $18.28, Total: $246.74 +Status: Processing +""" + +results = extractor.extract_json( + text, + { + "order": [ + "order_id::str::Order number", + "customer::str::Customer name", + "items::list::Product names", + "quantities::list::Item quantities", + "unit_prices::list::Individual prices", + "subtotal::str", + "tax::str", + "total::str", + "status::[pending|processing|shipped|delivered]::str" + ] + } +) +# Output: { +# 'order': [{ +# 'order_id': 'ORD-2024-001', +# 'customer': 'Alexandra Thompson', +# 'items': ['Laptop Stand', 'Wireless Mouse', 'USB Hub'], +# 'quantities': ['2', '1', '3'], +# 'unit_prices': ['$45.99', '$29.99', '$35.50'], +# 'subtotal': '$228.46', +# 'tax': '$18.28', +# 'total': '$246.74', +# 'status': 'processing' +# }] +# } +``` + +## Confidence Scores and Character Positions + +You can include confidence scores and character-level start/end positions for structured extraction: + +```python +# Extract with confidence scores +text = "The MacBook Pro costs $1999 and features M3 chip, 16GB RAM, and 512GB storage." 
+results = extractor.extract_json( + text, + { + "product": [ + "name::str", + "price", + "features" + ] + }, + include_confidence=True +) +# Output: { +# 'product': [{ +# 'name': {'text': 'MacBook Pro', 'confidence': 0.95}, +# 'price': [{'text': '$1999', 'confidence': 0.92}], +# 'features': [ +# {'text': 'M3 chip', 'confidence': 0.88}, +# {'text': '16GB RAM', 'confidence': 0.90}, +# {'text': '512GB storage', 'confidence': 0.87} +# ] +# }] +# } + +# Extract with character positions (spans) +results = extractor.extract_json( + text, + { + "product": [ + "name::str", + "price" + ] + }, + include_spans=True +) +# Output: { +# 'product': [{ +# 'name': {'text': 'MacBook Pro', 'start': 4, 'end': 15}, +# 'price': [{'text': '$1999', 'start': 22, 'end': 27}] +# }] +# } + +# Extract with both confidence and spans +results = extractor.extract_json( + text, + { + "product": [ + "name::str", + "price", + "features" + ] + }, + include_confidence=True, + include_spans=True +) +# Output: { +# 'product': [{ +# 'name': {'text': 'MacBook Pro', 'confidence': 0.95, 'start': 4, 'end': 15}, +# 'price': [{'text': '$1999', 'confidence': 0.92, 'start': 22, 'end': 27}], +# 'features': [ +# {'text': 'M3 chip', 'confidence': 0.88, 'start': 32, 'end': 39}, +# {'text': '16GB RAM', 'confidence': 0.90, 'start': 41, 'end': 49}, +# {'text': '512GB storage', 'confidence': 0.87, 'start': 55, 'end': 68} +# ] +# }] +# } +``` + +**Note**: When `include_spans` or `include_confidence` is True: +- **String fields** (`dtype="str"`): Return dicts with `{'text': '...', 'confidence': 0.95, 'start': 0, 'end': 5}` (or subset) +- **List fields** (`dtype="list"`): Return lists of dicts, each with text, confidence, and positions +- **Default** (both False): Returns simple strings or lists of strings + +## Best Practices + +### Data Types + +- Use `::str` for single values (IDs, names, amounts) +- Use `::list` or default for multiple values (features, items, tags) +- Use choices `[opt1|opt2|opt3]` for standardized 
values +- Add descriptions for complex or domain-specific fields + +### Quick Decision Guide + +**Use `extract_json()`** for: +- Structure-only extraction +- Quick data parsing +- Single extraction task + +**Use `create_schema().extract()`** for: +- Multi-task scenarios (entities + structures + classification) +- When you need entities or classification alongside structures +- Complex extraction pipelines \ No newline at end of file diff --git a/packages/GLiNER2/tutorial/4-combined.md b/packages/GLiNER2/tutorial/4-combined.md new file mode 100644 index 0000000..9f61fb6 --- /dev/null +++ b/packages/GLiNER2/tutorial/4-combined.md @@ -0,0 +1,357 @@ +# GLiNER2 Combining Schemas Tutorial + +## Table of Contents +- [Why Combine Schemas](#why-combine-schemas) +- [Basic Combinations](#basic-combinations) +- [Advanced Multi-Task Schemas](#advanced-multi-task-schemas) +- [Real-World Applications](#real-world-applications) + +## Why Combine Schemas + +Combining schemas allows you to: +- Extract multiple types of information in one pass +- Maintain context between different extraction tasks +- Improve efficiency by avoiding multiple model calls +- Build comprehensive information extraction pipelines + +## Basic Combinations + +### Entities + Classification + +```python +from gliner2 import GLiNER2 + +extractor = GLiNER2.from_pretrained("your-model-name") + +# Sentiment analysis with entity extraction +schema = (extractor.create_schema() + .entities(["person", "product", "company"]) + .classification("sentiment", ["positive", "negative", "neutral"]) + .classification("category", ["review", "news", "opinion"]) +) + +text = "Tim Cook announced that Apple's new iPhone is exceeding sales expectations." 
+results = extractor.extract(text, schema) +# Output: { +# 'entities': { +# 'person': ['Tim Cook'], +# 'product': ['iPhone'], +# 'company': ['Apple'] +# }, +# 'sentiment': 'positive', +# 'category': 'news' +# } +``` + +### Entities + Structures + +```python +schema = (extractor.create_schema() + .entities({ + "person": "Names of people mentioned", + "date": "Dates and time references" + }) + .structure("appointment") + .field("patient", dtype="str") + .field("doctor", dtype="str") + .field("date") + .field("time") + .field("type", dtype="str", choices=["checkup", "followup", "consultation"]) +) + +text = """ +Dr. Sarah Johnson confirmed the appointment with John Smith for +March 15th at 2:30 PM. This will be a follow-up consultation +regarding his previous visit on February 1st. +""" +results = extractor.extract(text, schema) +``` + +### Classification + Structures + +```python +schema = (extractor.create_schema() + .classification("email_type", + ["order_confirmation", "shipping_update", "promotional", "support"]) + .classification("priority", ["urgent", "normal", "low"]) + .structure("order_info") + .field("order_number", dtype="str") + .field("items") + .field("total", dtype="str") + .field("status", dtype="str", + choices=["pending", "processing", "shipped", "delivered"]) +) +``` + +## Advanced Multi-Task Schemas + +### Complete Document Analysis + +```python +# Comprehensive invoice extraction +invoice_schema = (extractor.create_schema() + # Document classification + .classification("document_type", + ["invoice", "credit_note", "purchase_order", "receipt"]) + .classification("payment_status", + ["paid", "unpaid", "partial", "overdue"]) + + # Key entities + .entities({ + "company": "Company names (buyer or seller)", + "person": "Contact person names", + "date": "Important dates", + "amount": "Monetary amounts" + }) + + # Structured information + .structure("invoice_header") + .field("invoice_number", dtype="str") + .field("issue_date", dtype="str") + 
.field("due_date", dtype="str") + .field("vendor_name", dtype="str") + .field("customer_name", dtype="str") + + .structure("line_item") + .field("description", dtype="str") + .field("quantity") + .field("unit_price") + .field("amount") + .field("tax_rate", dtype="str", choices=["0%", "5%", "10%", "20%"]) + + .structure("payment_info") + .field("method", dtype="str", + choices=["bank_transfer", "credit_card", "check", "cash"]) + .field("terms", description="Payment terms like NET30") + .field("bank_details", dtype="list") +) +``` + +### Customer Feedback Analysis + +```python +feedback_schema = (extractor.create_schema() + # Overall classifications + .classification("sentiment", ["positive", "negative", "neutral", "mixed"]) + .classification("intent", { + "complaint": "Customer expressing dissatisfaction", + "compliment": "Customer expressing satisfaction", + "suggestion": "Customer providing improvement ideas", + "question": "Customer asking for information" + }, multi_label=True) + + # Extract mentioned entities + .entities({ + "product": "Products or services mentioned", + "feature": "Specific features discussed", + "competitor": "Competing products mentioned", + "price_mention": "Price points or cost references" + }) + + # Structured feedback components + .structure("issue") + .field("problem", dtype="str") + .field("severity", dtype="str", choices=["critical", "major", "minor"]) + .field("affected_area", dtype="list") + + .structure("suggestion") + .field("improvement", dtype="str") + .field("benefit", description="Expected benefit of the suggestion") +) +``` + +### News Article Analysis + +```python +news_schema = (extractor.create_schema() + # Article metadata + .classification("category", + ["politics", "business", "technology", "sports", "entertainment"]) + .classification("bias", ["left", "center", "right", "neutral"]) + .classification("factuality", ["fact", "opinion", "analysis", "speculation"]) + + # Key entities + .entities({ + "person": "People 
mentioned in the article", + "organization": "Companies, agencies, or groups", + "location": "Places, cities, or countries", + "event": "Named events or incidents" + }) + + # Structured content + .structure("quote") + .field("speaker", dtype="str") + .field("statement", dtype="str") + .field("context", description="Context of the quote") + + .structure("claim") + .field("statement", dtype="str") + .field("source", dtype="str") + .field("evidence", dtype="list") +) +``` + +## Real-World Applications + +### E-commerce Product Listing + +```python +product_schema = (extractor.create_schema() + # Listing classification + .classification("condition", ["new", "used", "refurbished", "for_parts"]) + .classification("listing_type", ["buy_now", "auction", "best_offer"]) + + # Extract key entities + .entities({ + "brand": "Product brand or manufacturer", + "model": "Specific model name or number", + "color": "Product colors mentioned", + "size": "Size specifications" + }) + + # Product details + .structure("product") + .field("title", dtype="str") + .field("price", dtype="str") + .field("features", dtype="list") + .field("category", dtype="str") + + # Shipping information + .structure("shipping") + .field("method", dtype="list", + choices=["standard", "express", "overnight", "international"]) + .field("cost", dtype="str") + .field("delivery_time", description="Estimated delivery timeframe") + + # Seller information + .structure("seller") + .field("name", dtype="str") + .field("rating", dtype="str") + .field("location", dtype="str") +) +``` + +### Healthcare Clinical Note + +```python +clinical_schema = (extractor.create_schema() + # Note classification + .classification("visit_type", + ["initial_consultation", "follow_up", "emergency", "routine_checkup"]) + .classification("urgency", ["urgent", "routine", "elective"]) + + # Medical entities + .entities({ + "symptom": "Patient reported symptoms", + "diagnosis": "Medical diagnoses or conditions", + "medication": "Prescribed or 
mentioned medications", + "procedure": "Medical procedures or tests", + "body_part": "Anatomical references" + }) + + # Patient information + .structure("patient_info") + .field("name", dtype="str") + .field("age", dtype="str") + .field("gender", dtype="str", choices=["male", "female", "other"]) + .field("chief_complaint", dtype="str") + + # Clinical findings + .structure("vital_signs") + .field("blood_pressure", dtype="str") + .field("heart_rate", dtype="str") + .field("temperature", dtype="str") + .field("respiratory_rate", dtype="str") + + # Treatment plan + .structure("prescription") + .field("medication", dtype="str") + .field("dosage", dtype="str") + .field("frequency") + .field("duration") + .field("route", dtype="str", choices=["oral", "IV", "topical", "injection"]) +) +``` + +### Legal Document Analysis + +```python +legal_schema = (extractor.create_schema() + # Document classification + .classification("document_type", + ["contract", "memorandum", "brief", "motion", "order"]) + .classification("jurisdiction", + ["federal", "state", "local", "international"]) + + # Legal entities + .entities({ + "party": "Parties involved (plaintiff, defendant, etc.)", + "attorney": "Legal representatives", + "judge": "Judicial officers", + "statute": "Laws or regulations cited", + "case_citation": "Referenced legal cases" + }) + + # Contract terms + .structure("contract_term") + .field("clause_type", dtype="str", + choices=["payment", "delivery", "warranty", "liability", "termination"]) + .field("obligation", dtype="str") + .field("party_responsible", dtype="str") + .field("deadline") + + # Legal claims + .structure("claim") + .field("type", dtype="str") + .field("plaintiff", dtype="str") + .field("defendant", dtype="str") + .field("amount", dtype="str") + .field("basis", description="Legal basis for the claim") +) +``` + +## Using Confidence Scores and Character Positions with Combined Schemas + +When using combined schemas, `include_confidence` and `include_spans` 
parameters apply to all extraction types: + +```python +schema = (extractor.create_schema() + .entities(["person", "company"]) + .classification("sentiment", ["positive", "negative", "neutral"]) + .relations(["works_for"]) + .structure("product") + .field("name", dtype="str") + .field("price", dtype="str") +) + +text = "Tim Cook works for Apple. The iPhone 15 costs $999. This is exciting!" +results = extractor.extract( + text, + schema, + include_confidence=True, + include_spans=True +) +# Output: { +# 'entities': { +# 'person': [ +# {'text': 'Tim Cook', 'confidence': 0.95, 'start': 0, 'end': 8} +# ], +# 'company': [ +# {'text': 'Apple', 'confidence': 0.92, 'start': 20, 'end': 25} +# ] +# }, +# 'sentiment': {'label': 'positive', 'confidence': 0.88}, +# 'relation_extraction': { +# 'works_for': [{ +# 'head': {'text': 'Tim Cook', 'confidence': 0.95, 'start': 0, 'end': 8}, +# 'tail': {'text': 'Apple', 'confidence': 0.92, 'start': 20, 'end': 25} +# }] +# }, +# 'product': [{ +# 'name': {'text': 'iPhone 15', 'confidence': 0.90, 'start': 30, 'end': 39}, +# 'price': {'text': '$999', 'confidence': 0.88, 'start': 46, 'end': 51} +# }] +# } +``` + +**Note**: The `include_confidence` and `include_spans` parameters work consistently across all extraction types (entities, classifications, relations, and structures) when using combined schemas. \ No newline at end of file diff --git a/packages/GLiNER2/tutorial/5-validator.md b/packages/GLiNER2/tutorial/5-validator.md new file mode 100644 index 0000000..a21a85f --- /dev/null +++ b/packages/GLiNER2/tutorial/5-validator.md @@ -0,0 +1,112 @@ +# GLiNER2 Regex Validators + +Regex validators filter extracted spans to ensure they match expected patterns, improving extraction quality and reducing false positives. 
+ +## Quick Start + +```python +from gliner2 import GLiNER2, RegexValidator + +extractor = GLiNER2.from_pretrained("your-model") + +# Create validator and apply to field +email_validator = RegexValidator(r"^[\w\.-]+@[\w\.-]+\.\w+$") +schema = (extractor.create_schema() + .structure("contact") + .field("email", dtype="str", validators=[email_validator]) +) +``` + +## RegexValidator Parameters + +- **pattern**: Regex pattern (string or compiled Pattern) +- **mode**: `"full"` (exact match) or `"partial"` (substring match) +- **exclude**: `False` (keep matches) or `True` (exclude matches) +- **flags**: Regex flags like `re.IGNORECASE` (for string patterns only) + +## Examples + +### Email Validation +```python +email_validator = RegexValidator(r"^[\w\.-]+@[\w\.-]+\.\w+$") + +text = "Contact: john@company.com, not-an-email, jane@domain.org" +# Output: ['john@company.com', 'jane@domain.org'] +``` + +### Phone Numbers (US Format) +```python +phone_validator = RegexValidator(r"\(\d{3}\)\s\d{3}-\d{4}", mode="partial") + +text = "Call (555) 123-4567 or 5551234567" +# Output: ['(555) 123-4567'] # Second number filtered out +``` + +### URLs Only +```python +url_validator = RegexValidator(r"^https?://", mode="partial") + +text = "Visit https://example.com or www.site.com" +# Output: ['https://example.com'] # www.site.com filtered out +``` + +### Exclude Test Data +```python +no_test_validator = RegexValidator(r"^(test|demo|sample)", exclude=True, flags=re.IGNORECASE) + +text = "Products: iPhone, Test Phone, Samsung Galaxy" +# Output: ['iPhone', 'Samsung Galaxy'] # Test Phone excluded +``` + +### Length Constraints +```python +length_validator = RegexValidator(r"^.{5,50}$") # 5-50 characters + +text = "Names: Jo, Alexander, A Very Long Name That Exceeds Fifty Characters" +# Output: ['Alexander'] # Others filtered by length +``` + +### Multiple Validators +```python +# All validators must pass +username_validators = [ + RegexValidator(r"^[a-zA-Z0-9_]+$"), # Alphanumeric + 
underscore + RegexValidator(r"^.{3,20}$"), # 3-20 characters + RegexValidator(r"^(?!admin)", exclude=True, flags=re.IGNORECASE) # No "admin" +] + +schema = (extractor.create_schema() + .structure("user") + .field("username", dtype="str", validators=username_validators) +) + +text = "Users: ab, john_doe, user@domain, admin, valid_user123" +# Output: ['john_doe', 'valid_user123'] +``` + +## Common Patterns + +| Use Case | Pattern | Mode | +|----------|---------|------| +| Email | `r"^[\w\.-]+@[\w\.-]+\.\w+$"` | full | +| Phone (US) | `r"\(\d{3}\)\s\d{3}-\d{4}"` | partial | +| URL | `r"^https?://"` | partial | +| Numbers only | `r"^\d+$"` | full | +| No spaces | `r"^\S+$"` | full | +| Min length | `r"^.{5,}$"` | full | +| Alphanumeric | `r"^[a-zA-Z0-9]+$"` | full | + +## Best Practices + +1. **Use specific patterns** - More specific = fewer false positives +2. **Test your regex** - Validate patterns before deployment +3. **Combine validators** - Chain multiple simple validators +4. **Consider case sensitivity** - Use `re.IGNORECASE` when needed +5. **Start simple** - Begin with basic patterns, refine as needed + +## Performance Notes + +- Validators run after span extraction but before formatting +- Failed validation simply excludes the span (no errors) +- Multiple validators use short-circuit evaluation (stops at first failure) +- Compiled patterns are cached automatically \ No newline at end of file diff --git a/packages/GLiNER2/tutorial/6-relation_extraction.md b/packages/GLiNER2/tutorial/6-relation_extraction.md new file mode 100644 index 0000000..c69aea0 --- /dev/null +++ b/packages/GLiNER2/tutorial/6-relation_extraction.md @@ -0,0 +1,643 @@ +# GLiNER2 Relation Extraction Tutorial + +Learn how to extract relations between entities from text using GLiNER2's relation extraction capabilities. 
+ +## Table of Contents +- [Basic Relation Extraction](#basic-relation-extraction) +- [Multiple Relation Types](#multiple-relation-types) +- [Relation Extraction with Descriptions](#relation-extraction-with-descriptions) +- [Custom Thresholds](#custom-thresholds) +- [Batch Processing](#batch-processing) +- [Combining with Other Tasks](#combining-with-other-tasks) +- [Real-World Examples](#real-world-examples) +- [Best Practices](#best-practices) + +## Basic Relation Extraction + +### Simple Example + +```python +from gliner2 import GLiNER2 + +# Load model +extractor = GLiNER2.from_pretrained("your-model-name") + +# Extract relations +text = "John works for Apple Inc. and lives in San Francisco." +results = extractor.extract_relations( + text, + ["works_for", "lives_in"] +) +print(results) +# Output: { +# 'relation_extraction': { +# 'works_for': [('John', 'Apple Inc.')], +# 'lives_in': [('John', 'San Francisco')] +# } +# } +``` + +### Using Schema Builder + +```python +# Same extraction using schema +schema = extractor.create_schema().relations([ + "works_for", "lives_in" +]) +results = extractor.extract(text, schema) +``` + +### Understanding the Output Format + +Relations are returned as tuples `(source, target)` grouped under the `relation_extraction` key. **All requested relation types are included in the output, even if no relations are found** (they appear as empty lists `[]`): + +```python +text = "Alice manages the Engineering team. Bob reports to Alice." +results = extractor.extract_relations( + text, + ["manages", "reports_to", "founded"] # Note: "founded" not found in text +) +# Output: { +# 'relation_extraction': { +# 'manages': [('Alice', 'Engineering team')], +# 'reports_to': [('Bob', 'Alice')], +# 'founded': [] # Empty list - relation type requested but not found +# } +# } +``` + +This ensures consistent output structure - all requested relation types will always be present in the results, making it easier to process the output programmatically. 
+ +## Multiple Relation Types + +You can extract multiple relation types in a single call: + +```python +text = """ +Sarah founded TechCorp in 2020. She is married to Mike, +who works at Google. TechCorp is located in Seattle. +""" + +results = extractor.extract_relations( + text, + ["founded", "married_to", "works_at", "located_in"] +) +# Output: { +# 'relation_extraction': { +# 'founded': [('Sarah', 'TechCorp')], +# 'married_to': [('Sarah', 'Mike')], +# 'works_at': [('Mike', 'Google')], +# 'located_in': [('TechCorp', 'Seattle')] +# } +# } +``` + +### Multiple Instances per Relation Type + +GLiNER2 automatically extracts all relation instances found in the text: + +```python +text = """ +John works for Microsoft. Mary works for Google. +Bob works for Apple. All three live in California. +""" + +results = extractor.extract_relations( + text, + ["works_for", "lives_in"] +) +# Output: { +# 'relation_extraction': { +# 'works_for': [ +# ('John', 'Microsoft'), +# ('Mary', 'Google'), +# ('Bob', 'Apple') +# ], +# 'lives_in': [ +# ('John', 'California'), +# ('Mary', 'California'), +# ('Bob', 'California') +# ] +# } +# } +``` + +## Relation Extraction with Descriptions + +Providing descriptions helps improve extraction accuracy by clarifying the relation semantics: + +```python +schema = extractor.create_schema().relations({ + "works_for": "Employment relationship where person works at organization", + "founded": "Founding relationship where person created organization", + "acquired": "Acquisition relationship where company bought another company", + "located_in": "Geographic relationship where entity is in a location" +}) + +text = """ +Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne, California. +Tesla acquired SolarCity in 2016. Many engineers work for SpaceX. 
+""" + +results = extractor.extract(text, schema) +``` + +### Advanced Configuration + +```python +schema = extractor.create_schema().relations({ + "works_for": { + "description": "Employment or professional relationship", + "threshold": 0.7 # Higher precision for employment relations + }, + "located_in": { + "description": "Geographic containment relationship", + "threshold": 0.6 # Moderate threshold + }, + "reports_to": { + "description": "Organizational hierarchy relationship", + "threshold": 0.8 # Very high precision + } +}) +``` + +## Custom Thresholds + +### Global Threshold + +```python +# High-precision relation extraction +results = extractor.extract_relations( + text, + ["acquired", "merged_with"], + threshold=0.8 # High confidence required +) +``` + +### Per-Relation Thresholds + +```python +schema = extractor.create_schema().relations({ + "acquired": { + "description": "Company acquisition relationship", + "threshold": 0.9 # Very high precision + }, + "partnered_with": { + "description": "Partnership or collaboration relationship", + "threshold": 0.6 # Moderate threshold + }, + "competes_with": { + "description": "Competitive relationship", + "threshold": 0.5 # Lower threshold for implicit relations + } +}) +``` + +### With Confidence Scores and Character Positions + +You can include confidence scores and character-level start/end positions for relation extractions: + +```python +# Extract relations with confidence scores +text = "John works for Apple Inc. and lives in San Francisco." 
+results = extractor.extract_relations( + text, + ["works_for", "lives_in"], + include_confidence=True +) +print(results) +# Output: { +# 'relation_extraction': { +# 'works_for': [{ +# 'head': {'text': 'John', 'confidence': 0.95}, +# 'tail': {'text': 'Apple Inc.', 'confidence': 0.92} +# }], +# 'lives_in': [{ +# 'head': {'text': 'John', 'confidence': 0.94}, +# 'tail': {'text': 'San Francisco', 'confidence': 0.91} +# }] +# } +# } + +# Extract with character positions (spans) +results = extractor.extract_relations( + text, + ["works_for", "lives_in"], + include_spans=True +) +print(results) +# Output: { +# 'relation_extraction': { +# 'works_for': [{ +# 'head': {'text': 'John', 'start': 0, 'end': 4}, +# 'tail': {'text': 'Apple Inc.', 'start': 15, 'end': 25} +# }], +# 'lives_in': [{ +# 'head': {'text': 'John', 'start': 0, 'end': 4}, +# 'tail': {'text': 'San Francisco', 'start': 33, 'end': 46} +# }] +# } +# } + +# Extract with both confidence and spans +results = extractor.extract_relations( + text, + ["works_for", "lives_in"], + include_confidence=True, + include_spans=True +) +print(results) +# Output: { +# 'relation_extraction': { +# 'works_for': [{ +# 'head': {'text': 'John', 'confidence': 0.95, 'start': 0, 'end': 4}, +# 'tail': {'text': 'Apple Inc.', 'confidence': 0.92, 'start': 15, 'end': 25} +# }], +# 'lives_in': [{ +# 'head': {'text': 'John', 'confidence': 0.94, 'start': 0, 'end': 4}, +# 'tail': {'text': 'San Francisco', 'confidence': 0.91, 'start': 33, 'end': 46} +# }] +# } +# } +``` + +**Note**: When `include_spans` or `include_confidence` is True, relations are returned as dictionaries with `head` and `tail` keys, each containing the extracted text along with optional confidence scores and character positions. When both are False (default), relations are returned as simple tuples `(head, tail)`. 
+ +## Batch Processing + +Process multiple texts efficiently: + +```python +texts = [ + "John works for Microsoft and lives in Seattle.", + "Sarah founded TechStartup in 2020.", + "Bob reports to Alice at Google." +] + +results = extractor.batch_extract_relations( + texts, + ["works_for", "founded", "reports_to", "lives_in"], + batch_size=8 +) +# Output: [ +# { +# 'relation_extraction': { +# 'works_for': [('John', 'Microsoft')], +# 'lives_in': [('John', 'Seattle')], +# 'founded': [], # Not found in first text +# 'reports_to': [] # Not found in first text +# } +# }, +# { +# 'relation_extraction': { +# 'works_for': [], # Not found in second text +# 'founded': [('Sarah', 'TechStartup')], +# 'reports_to': [], # Not found in second text +# 'lives_in': [] # Not found in second text +# } +# }, +# { +# 'relation_extraction': { +# 'works_for': [('Alice', 'Google')], +# 'reports_to': [('Bob', 'Alice')], +# 'founded': [], # Not found in third text +# 'lives_in': [] # Not found in third text +# } +# } +# ] +``` + +**Note**: All requested relation types appear in each result, even if empty. This ensures consistent structure across all batch results, making it easier to process programmatically. + +## Combining with Other Tasks + +Relation extraction can be combined with entity extraction, classification, and structured extraction: + +### Relations + Entities + +```python +schema = (extractor.create_schema() + .entities(["person", "organization", "location"]) + .relations(["works_for", "located_in"]) +) + +text = "Tim Cook works for Apple Inc., which is located in Cupertino, California." 
+results = extractor.extract(text, schema) +# Output: { +# 'entities': { +# 'person': ['Tim Cook'], +# 'organization': ['Apple Inc.'], +# 'location': ['Cupertino', 'California'] +# }, +# 'relation_extraction': { +# 'works_for': [('Tim Cook', 'Apple Inc.')], +# 'located_in': [('Apple Inc.', 'Cupertino')] +# } +# } +``` + +### Relations + Classification + Structures + +```python +schema = (extractor.create_schema() + .classification("document_type", ["news", "report", "announcement"]) + .entities(["person", "company"]) + .relations(["works_for", "acquired"]) + .structure("event") + .field("date", dtype="str") + .field("description", dtype="str") +) + +text = """ +BREAKING: Microsoft announced today that it acquired GitHub. +Satya Nadella, CEO of Microsoft, confirmed the deal. +The acquisition was finalized on October 26, 2018. +""" + +results = extractor.extract(text, schema) +``` + +## Real-World Examples + +### Organizational Relationships + +```python +org_schema = extractor.create_schema().relations({ + "reports_to": "Direct reporting relationship in organizational hierarchy", + "manages": "Management relationship where person manages team/department", + "works_for": "Employment relationship", + "founded": "Founding relationship", + "acquired": "Company acquisition relationship" +}) + +text = """ +Sundar Pichai is the CEO of Google. He reports to the board of directors. +Google acquired YouTube in 2006. Many engineers work for Google. 
+""" + +results = extractor.extract(text, org_schema) +# Output: { +# 'relation_extraction': { +# 'reports_to': [('Sundar Pichai', 'board of directors')], +# 'works_for': [('engineers', 'Google')], +# 'acquired': [('Google', 'YouTube')] +# } +# } +``` + +### Medical Relationships + +```python +medical_schema = extractor.create_schema().relations({ + "treats": "Medical treatment relationship between doctor and patient", + "prescribed_for": "Prescription relationship between medication and condition", + "causes": "Causal relationship between condition and symptom", + "located_in": "Anatomical location relationship" +}) + +text = """ +Dr. Smith treats patients with diabetes. Metformin is prescribed for Type 2 Diabetes. +High blood sugar causes frequent urination. The pancreas is located in the abdomen. +""" + +results = extractor.extract(text, medical_schema) +``` + +### Financial Relationships + +```python +finance_schema = extractor.create_schema().relations({ + "invested_in": "Investment relationship between investor and company", + "acquired": "Company acquisition relationship", + "merged_with": "Merger relationship between companies", + "owns": "Ownership relationship" +}) + +text = """ +SoftBank invested in Uber in 2018. Microsoft acquired LinkedIn in 2016. +Disney merged with 21st Century Fox. Berkshire Hathaway owns Geico. +""" + +results = extractor.extract(text, finance_schema) +``` + +### Geographic Relationships + +```python +geo_schema = extractor.create_schema().relations({ + "located_in": "Geographic containment (city in country, etc.)", + "borders": "Geographic adjacency relationship", + "capital_of": "Capital city relationship", + "flows_through": "River or waterway relationship" +}) + +text = """ +Paris is the capital of France. France borders Germany and Spain. +The Seine flows through Paris. Paris is located in France. 
+""" + +results = extractor.extract(text, geo_schema) +``` + +### Family Relationships + +```python +family_schema = extractor.create_schema().relations({ + "married_to": "Marriage relationship", + "parent_of": "Parent-child relationship", + "sibling_of": "Sibling relationship", + "related_to": "General family relationship" +}) + +text = """ +John is married to Mary. They are parents of two children: Alice and Bob. +Alice and Bob are siblings. Mary is related to her sister Sarah. +""" + +results = extractor.extract(text, family_schema) +``` + +### Academic Relationships + +```python +academic_schema = extractor.create_schema().relations({ + "authored": "Publication relationship between author and paper", + "cited": "Citation relationship between papers", + "supervised": "Academic supervision relationship", + "affiliated_with": "Institutional affiliation relationship" +}) + +text = """ +Dr. Johnson authored the paper on machine learning. The paper cited +previous work by Dr. Smith. Dr. Johnson supervises graduate students +at MIT, where she is affiliated with the Computer Science department. +""" + +results = extractor.extract(text, academic_schema) +``` + +## Best Practices + +### 1. Use Clear, Specific Relation Names + +```python +# Good - Clear and specific +schema.relations(["works_for", "reports_to", "manages"]) + +# Less ideal - Too generic +schema.relations(["related", "connected", "linked"]) +``` + +### 2. Provide Descriptions for Ambiguous Relations + +```python +# Good - Clear descriptions +schema.relations({ + "works_for": "Employment relationship where person works at organization", + "consulted_for": "Consulting relationship where person provides services to organization" +}) + +# Less ideal - No context +schema.relations(["works_for", "consulted_for"]) +``` + +### 3. 
Set Appropriate Thresholds
+
+```python
+# High precision for critical relations
+schema.relations({
+    "acquired": {
+        "description": "Company acquisition",
+        "threshold": 0.9  # Very high precision
+    },
+    "partnered_with": {
+        "description": "Partnership relationship",
+        "threshold": 0.6  # Moderate threshold
+    }
+})
+```
+
+### 4. Combine with Entity Extraction
+
+```python
+# Extract both entities and relations for better context
+schema = (extractor.create_schema()
+    .entities(["person", "organization"])
+    .relations(["works_for", "founded"])
+)
+```
+
+### 5. Use Batch Processing for Multiple Texts
+
+```python
+# Efficient batch processing
+results = extractor.batch_extract_relations(
+    texts,
+    relation_types,
+    batch_size=8  # Adjust based on your hardware
+)
+```
+
+### 6. Handle Multiple Instances
+
+```python
+# GLiNER2 automatically extracts all instances
+text = "John works for Apple. Mary works for Google. Bob works for Microsoft."
+results = extractor.extract_relations(text, ["works_for"])
+# Returns all three work relationships
+```
+
+### 7. Handle Empty Relations
+
+All requested relation types are always included in the output, even if empty:
+
+```python
+results = extractor.extract_relations(
+    "John works for Microsoft.",
+    ["works_for", "founded", "acquired"]
+)
+# Output: {
+#   'relation_extraction': {
+#     'works_for': [('John', 'Microsoft')],
+#     'founded': [],   # Empty - not found in text
+#     'acquired': []   # Empty - not found in text
+#   }
+# }
+
+# This makes it easy to check for relations programmatically:
+for rel_type, rels in results['relation_extraction'].items():
+    if rels:  # Non-empty
+        print(f"Found {len(rels)} {rel_type} relations")
+    else:  # Empty
+        print(f"No {rel_type} relations found")
+```
+
+### 8. 
Validate Relation Direction + +Relations are directional tuples `(source, target)`: +- `works_for`: (person, organization) +- `located_in`: (entity, location) +- `reports_to`: (subordinate, manager) +- `manages`: (manager, team) + +Make sure your relation names match the expected direction. + +## Common Use Cases + +### Knowledge Graph Construction + +```python +# Extract entities and relations for knowledge graph +schema = (extractor.create_schema() + .entities(["person", "organization", "location", "product"]) + .relations([ + "works_for", "founded", "located_in", "created", + "acquired", "partnered_with" + ]) +) + +# Process documents to build knowledge graph +documents = [...] # Your documents +all_relations = [] +all_entities = [] + +for doc in documents: + results = extractor.extract(doc, schema) + all_relations.append(results.get("relation_extraction", {})) + all_entities.append(results.get("entities", {})) +``` + +### Relationship Analysis + +```python +# Analyze organizational structures +org_texts = [...] 
# Organizational documents +results = extractor.batch_extract_relations( + org_texts, + ["reports_to", "manages", "works_for", "collaborates_with"], + batch_size=8 +) + +# Analyze relationship patterns +for result in results: + relations = result.get("relation_extraction", {}) + # Process relations for analysis +``` + +### Document Understanding + +```python +# Comprehensive document understanding +schema = (extractor.create_schema() + .classification("document_type", ["contract", "report", "email"]) + .entities(["person", "organization", "date", "amount"]) + .relations(["signed_by", "involves", "dated", "worth"]) + .structure("contract_term") + .field("term", dtype="str") + .field("value", dtype="str") +) + +# Extract all information types in one pass +results = extractor.extract(document_text, schema) +``` + diff --git a/packages/GLiNER2/tutorial/7-api.md b/packages/GLiNER2/tutorial/7-api.md new file mode 100644 index 0000000..18c9e62 --- /dev/null +++ b/packages/GLiNER2/tutorial/7-api.md @@ -0,0 +1,514 @@ +# GLiNER2 API Extractor + +Use GLiNER2 through a cloud API without loading models locally. Perfect for production deployments, low-memory environments, or when you need instant access without GPU setup. + +## Table of Contents +- [Getting Started](#getting-started) +- [Basic Usage](#basic-usage) +- [Entity Extraction](#entity-extraction) +- [Text Classification](#text-classification) +- [Structured Extraction](#structured-extraction) +- [Relation Extraction](#relation-extraction) +- [Combined Schemas](#combined-schemas) +- [Batch Processing](#batch-processing) +- [Confidence Scores](#confidence-scores) +- [Error Handling](#error-handling) +- [API vs Local](#api-vs-local) + +## Getting Started + +### Get Your API Key + +1. Visit [gliner.pioneer.ai](https://gliner.pioneer.ai) +2. Sign up or log in to your account +3. Navigate to API Keys section +4. 
Generate a new API key + +### Installation + +```bash +pip install gliner2 +``` + +### Set Your API Key + +**Option 1: Environment Variable (Recommended)** +```bash +export PIONEER_API_KEY="your-api-key-here" +``` + +**Option 2: Pass Directly** +```python +extractor = GLiNER2.from_api(api_key="your-api-key-here") +``` + +## Basic Usage + +```python +from gliner2 import GLiNER2 + +# Load from API (uses PIONEER_API_KEY environment variable) +extractor = GLiNER2.from_api() + +# Use exactly like the local model! +results = extractor.extract_entities( + "Apple CEO Tim Cook announced the iPhone 15 in Cupertino.", + ["company", "person", "product", "location"] +) +print(results) +# Output: { +# 'entities': { +# 'company': ['Apple'], +# 'person': ['Tim Cook'], +# 'product': ['iPhone 15'], +# 'location': ['Cupertino'] +# } +# } +``` + +## Entity Extraction + +### Simple Extraction + +```python +extractor = GLiNER2.from_api() + +text = "Elon Musk founded SpaceX in 2002 and Tesla in 2003." +results = extractor.extract_entities( + text, + ["person", "company", "date"] +) +# Output: { +# 'entities': { +# 'person': ['Elon Musk'], +# 'company': ['SpaceX', 'Tesla'], +# 'date': ['2002', '2003'] +# } +# } +``` + +### With Confidence Scores and Character Positions + +You can include confidence scores and character-level start/end positions using `include_confidence` and `include_spans`: + +```python +# With confidence only +results = extractor.extract_entities( + "Microsoft acquired LinkedIn for $26.2 billion.", + ["company", "price"], + include_confidence=True +) +# Output: { +# 'entities': { +# 'company': [ +# {'text': 'Microsoft', 'confidence': 0.98}, +# {'text': 'LinkedIn', 'confidence': 0.97} +# ], +# 'price': [ +# {'text': '$26.2 billion', 'confidence': 0.95} +# ] +# } +# } + +# With character positions (spans) only +results = extractor.extract_entities( + "Microsoft acquired LinkedIn.", + ["company"], + include_spans=True +) +# Output: { +# 'entities': { +# 'company': [ +# 
{'text': 'Microsoft', 'start': 0, 'end': 9},
+#       {'text': 'LinkedIn', 'start': 19, 'end': 27}
+#     ]
+#   }
+# }
+
+# With both confidence and spans
+results = extractor.extract_entities(
+    "Microsoft acquired LinkedIn for $26.2 billion.",
+    ["company", "price"],
+    include_confidence=True,
+    include_spans=True
+)
+# Output: {
+#   'entities': {
+#     'company': [
+#       {'text': 'Microsoft', 'confidence': 0.98, 'start': 0, 'end': 9},
+#       {'text': 'LinkedIn', 'confidence': 0.97, 'start': 19, 'end': 27}
+#     ],
+#     'price': [
+#       {'text': '$26.2 billion', 'confidence': 0.95, 'start': 32, 'end': 45}
+#     ]
+#   }
+# }
+```
+
+### Custom Threshold
+
+```python
+# Only return high-confidence extractions
+results = extractor.extract_entities(
+    text,
+    ["person", "company"],
+    threshold=0.8  # Minimum 80% confidence
+)
+```
+
+## Text Classification
+
+### Single-Label Classification
+
+```python
+extractor = GLiNER2.from_api()
+
+text = "I absolutely love this product! It exceeded all my expectations."
+results = extractor.classify_text(
+    text,
+    {"sentiment": ["positive", "negative", "neutral"]}
+)
+# Output: {'sentiment': {'category': 'positive'}}
+```
+
+### Multi-Task Classification
+
+```python
+text = "Breaking: Major earthquake hits coastal city. Rescue teams deployed."
+results = extractor.classify_text(
+    text,
+    {
+        "category": ["politics", "sports", "technology", "disaster", "business"],
+        "urgency": ["low", "medium", "high"]
+    }
+)
+# Output: {'category': 'disaster', 'urgency': 'high'}
+```
+
+## Structured Extraction
+
+### Contact Information
+
+```python
+extractor = GLiNER2.from_api()
+
+text = """
+Contact John Smith at john.smith@email.com or call +1-555-123-4567.
+He works as a Senior Engineer at TechCorp Inc. 
+""" + +results = extractor.extract_json( + text, + { + "contact": [ + "name::str::Full name of the person", + "email::str::Email address", + "phone::str::Phone number", + "job_title::str::Professional title", + "company::str::Company name" + ] + } +) +# Output: { +# 'contact': [{ +# 'name': 'John Smith', +# 'email': 'john.smith@email.com', +# 'phone': '+1-555-123-4567', +# 'job_title': 'Senior Engineer', +# 'company': 'TechCorp Inc.' +# }] +# } +``` + +### Product Information + +```python +text = "iPhone 15 Pro Max - $1199, 256GB storage, Natural Titanium color" + +results = extractor.extract_json( + text, + { + "product": [ + "name::str", + "price::str", + "storage::str", + "color::str" + ] + } +) +# Output: { +# 'product': [{ +# 'name': 'iPhone 15 Pro Max', +# 'price': '$1199', +# 'storage': '256GB', +# 'color': 'Natural Titanium' +# }] +# } +``` + +## Relation Extraction + +Extract relationships between entities as directional tuples (source, target). + +### Basic Relation Extraction + +```python +extractor = GLiNER2.from_api() + +text = "John works for Apple Inc. and lives in San Francisco. Apple Inc. is located in Cupertino." +results = extractor.extract_relations( + text, + ["works_for", "lives_in", "located_in"] +) +# Output: { +# 'relation_extraction': { +# 'works_for': [('John', 'Apple Inc.')], +# 'lives_in': [('John', 'San Francisco')], +# 'located_in': [('Apple Inc.', 'Cupertino')] +# } +# } +``` + +### With Descriptions + +```python +text = "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne, California." 
+ +schema = extractor.create_schema().relations({ + "founded": "Founding relationship where person created organization", + "located_in": "Geographic relationship where entity is in a location" +}) + +results = extractor.extract(text, schema) +# Output: { +# 'relation_extraction': { +# 'founded': [('Elon Musk', 'SpaceX')], +# 'located_in': [('SpaceX', 'Hawthorne, California')] +# } +# } +``` + +### Batch Relation Extraction + +```python +texts = [ + "John works for Microsoft and lives in Seattle.", + "Sarah founded TechStartup in 2020.", + "Bob reports to Alice at Google." +] + +results = extractor.batch_extract_relations( + texts, + ["works_for", "founded", "reports_to", "lives_in"] +) +# Returns list of relation extraction results for each text +``` + +## Combined Schemas + +Combine entities, classification, relations, and structured extraction in a single call. + +```python +extractor = GLiNER2.from_api() + +text = """ +Tech Review: The new MacBook Pro M3 is absolutely fantastic! Apple has outdone themselves. +I tested it in San Francisco last week. Tim Cook works for Apple, which is located in Cupertino. +Highly recommended for developers. Rating: 5 out of 5 stars. 
+""" + +schema = (extractor.create_schema() + .entities(["company", "product", "location", "person"]) + .classification("sentiment", ["positive", "negative", "neutral"]) + .relations(["works_for", "located_in"]) + .structure("review") + .field("product_name", dtype="str") + .field("rating", dtype="str") + .field("recommendation", dtype="str") +) + +results = extractor.extract(text, schema) +# Output: { +# 'entities': { +# 'company': ['Apple'], +# 'product': ['MacBook Pro M3'], +# 'location': ['San Francisco', 'Cupertino'], +# 'person': ['Tim Cook'] +# }, +# 'sentiment': 'positive', +# 'relation_extraction': { +# 'works_for': [('Tim Cook', 'Apple')], +# 'located_in': [('Apple', 'Cupertino')] +# }, +# 'review': [{ +# 'product_name': 'MacBook Pro M3', +# 'rating': '5 out of 5 stars', +# 'recommendation': 'Highly recommended for developers' +# }] +# } +``` + +## Batch Processing + +Process multiple texts efficiently in a single API call. + +```python +extractor = GLiNER2.from_api() + +texts = [ + "Google's Sundar Pichai unveiled Gemini AI in Mountain View.", + "Microsoft CEO Satya Nadella announced Copilot at Build 2023.", + "Amazon's Andy Jassy revealed new AWS services in Seattle." 
+] + +results = extractor.batch_extract_entities( + texts, + ["company", "person", "product", "location"] +) + +for i, result in enumerate(results): + print(f"Text {i+1}: {result}") +``` + +## Confidence Scores and Character Positions + +### Entity Extraction with Confidence + +```python +# Include confidence scores +results = extractor.extract_entities( + "Apple released the iPhone 15 in September 2023.", + ["company", "product", "date"], + include_confidence=True +) +# Each entity includes: {'text': '...', 'confidence': 0.95} +``` + +### Entity Extraction with Character Positions + +```python +# Include character-level start/end positions +results = extractor.extract_entities( + "Apple released the iPhone 15.", + ["company", "product"], + include_spans=True +) +# Each entity includes: {'text': '...', 'start': 0, 'end': 5} +``` + +### Both Confidence and Positions + +```python +# Include both confidence and character positions +results = extractor.extract_entities( + "Apple released the iPhone 15 in September 2023.", + ["company", "product", "date"], + include_confidence=True, + include_spans=True +) +# Each entity includes: {'text': '...', 'confidence': 0.95, 'start': 0, 'end': 5} +``` + +### Raw Results (Advanced) + +For full control over the extraction data: + +```python +results = extractor.extract_entities( + "Apple CEO Tim Cook announced new products.", + ["company", "person"], + format_results=False, # Get raw extraction data + include_confidence=True, + include_spans=True +) +# Returns tuples: (text, confidence, start_char, end_char) +``` + +## Error Handling + +```python +from gliner2 import GLiNER2, GLiNER2APIError, AuthenticationError, ValidationError + +try: + extractor = GLiNER2.from_api() + results = extractor.extract_entities(text, entity_types) + +except AuthenticationError: + print("Invalid API key. 
Check your PIONEER_API_KEY.") + +except ValidationError as e: + print(f"Invalid request: {e}") + +except GLiNER2APIError as e: + print(f"API error: {e}") +``` + +### Connection Settings + +```python +extractor = GLiNER2.from_api( + api_key="your-key", + timeout=60.0, # Request timeout (seconds) + max_retries=5 # Retry failed requests +) +``` + +## API vs Local + +| Feature | API (`from_api()`) | Local (`from_pretrained()`) | +|---------|-------------------|----------------------------| +| Setup | Just API key | GPU/CPU + model download | +| Memory | ~0 MB | 2-8 GB+ | +| Latency | Network dependent | Faster for single texts | +| Batch | Optimized | Optimized | +| Cost | Per request | Free after setup | +| Offline | ❌ | βœ… | +| RegexValidator | ❌ | βœ… | + +### When to Use API + +- Production deployments without GPU +- Serverless functions (AWS Lambda, etc.) +- Quick prototyping +- Low-memory environments +- Mobile/edge applications + +### When to Use Local + +- High-volume processing +- Offline requirements +- Sensitive data (no network transfer) +- Need for RegexValidator +- Cost optimization at scale + +## Seamless Switching + +The API mirrors the local interface exactly, making switching trivial: + +```python +# Development: Use API for quick iteration +extractor = GLiNER2.from_api() + +# Production: Switch to local if needed +# extractor = GLiNER2.from_pretrained("your-model") + +# Same code works with both! +results = extractor.extract_entities(text, entity_types) +``` + +## Limitations + +The API currently does not support: + +1. **RegexValidator** - Use local model for regex-based filtering +2. **Multi-schema batch** - Different schemas per text in batch (works but slower) +3. **Custom models** - API uses the default GLiNER2 model + +## Best Practices + +1. **Store API key securely** - Use environment variables, not hardcoded strings +2. **Handle errors gracefully** - Network issues can occur +3. 
**Use batch processing** - More efficient than individual calls +4. **Set appropriate timeouts** - Increase for large texts +5. **Cache results** - Avoid redundant API calls for same content + diff --git a/packages/GLiNER2/tutorial/8-train_data.md b/packages/GLiNER2/tutorial/8-train_data.md new file mode 100644 index 0000000..ee0656f --- /dev/null +++ b/packages/GLiNER2/tutorial/8-train_data.md @@ -0,0 +1,630 @@ +# GLiNER2 Training Dataset Formats + +GLiNER2 uses JSONL format where each line contains an `input` and `output` field (or alternatively `text` and `schema`). The `input`/`text` is the text to process, and the `output`/`schema` is the schema with labels/annotations. + +## Quick Format Reference + +### General Structure + +**Primary Format**: +```jsonl +{"input": "text to process", "output": {"schema_definition": "with_annotations"}} +``` + +**Alternative Format** (also supported): +```jsonl +{"text": "text to process", "schema": {"schema_definition": "with_annotations"}} +``` + +Both formats are equivalent - use whichever is more convenient for your workflow. 
+ +### Valid Output Schema Keys + +| Key | Type | Required | Description | +|-----|------|----------|-------------| +| `entities` | `dict[str, list[str]]` | No | Entity type β†’ list of entity mentions | +| `entity_descriptions` | `dict[str, str]` | No | Entity type β†’ description | +| `classifications` | `list[dict]` | No | List of classification tasks | +| `json_structures` | `list[dict]` | No | List of structured data extractions | +| `json_descriptions` | `dict[str, dict[str, str]]` | No | Parent β†’ field β†’ description | +| `relations` | `list[dict]` | No | List of relation extractions | + +### Classification Task Fields + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `task` | `str` | Yes | Task identifier | +| `labels` | `list[str]` | Yes | Available label options | +| `true_label` | `list[str]` or `str` | Yes | Correct label(s) | +| `multi_label` | `bool` | No | Enable multi-label classification | +| `prompt` | `str` | No | Custom prompt for the task | +| `examples` | `list[list[str]]` or `list[tuple[str, str]]` | No | Few-shot examples as [[input, output], ...] pairs. Internally converted to list of lists. 
| +| `label_descriptions` | `dict[str, str]` | No | Label β†’ description mapping | + +### Entity Fields Format + +Entities use a simple dictionary where keys are entity types and values are lists of mentions: + +| Component | Type | Required | Description | +|-----------|------|----------|-------------| +| Entity type (key) | `str` | Yes | Name of the entity type (e.g., "person", "location") | +| Entity mentions (value) | `list[str]` | Yes | List of entity text spans found in input | + +**Format**: `{"entity_type": ["mention1", "mention2", ...]}` + +### JSON Structure Fields Format + +Each structure is a dictionary with a parent name as key and field definitions as value: + +| Component | Type | Required | Description | +|-----------|------|----------|-------------| +| Parent name (key) | `str` | Yes | Name of the structure (e.g., "product", "contact") | +| Fields (value) | `dict` | Yes | Field name β†’ field value mappings | +| Field value | `str` or `list[str]` or `dict` | Yes | String, list of strings, or choice dict | +| Choice dict | `dict` with `value` and `choices` | No | For classification-style fields | + +**Format**: `[{"parent": {"field1": "value", "field2": ["list", "values"]}}]` + +**Multiple Instances**: When the same parent appears multiple times, each instance is a separate dict in the list: +```jsonl +[{"hotel": {"name": "Hotel A", ...}}, {"hotel": {"name": "Hotel B", ...}}] +``` + +### Relation Fields Format + +Relations use flexible field structures - you can use ANY field names (not just "head" and "tail"): + +| Component | Type | Required | Description | +|-----------|------|----------|-------------| +| Relation name (key) | `str` | Yes | Name of the relation type (e.g., "works_for") | +| Fields (value) | `dict` | Yes | Field name β†’ field value mappings | +| Field value | `str` or `list[str]` | Yes | String or list of strings | + +**Standard Format**: `[{"relation_name": {"head": "entity1", "tail": "entity2"}}]` + +**⚠️ Critical 
Constraint**: For a given relation type, the **first occurrence** defines the field structure: +- The first instance of "works_for" determines what fields ALL "works_for" instances must have +- All subsequent instances of the same relation type must use the same field names +- Different relation types can have different field structures +- **This consistency is enforced during validation** - inconsistent field structures will raise a `ValidationError` + +**Example**: If first "works_for" has `{"head": "...", "tail": "..."}`, all other "works_for" instances must also have "head" and "tail" fields. + +**Validation**: The `TrainingDataset.validate_relation_consistency()` method checks that all relation types have consistent field structures across the entire dataset. + +--- + +## Alternative Input Formats + +The training data loader supports multiple input formats: + +1. **JSONL files**: `{"input": "...", "output": {...}}` or `{"text": "...", "schema": {...}}` +2. **Python API**: Use `InputExample` and `TrainingDataset` classes from `gliner2.training.data` +3. **Dict lists**: List of dictionaries in the same format as JSONL + +All formats are automatically detected and converted to the internal format. See `gliner2.training.data.DataLoader_Factory` for details. + +--- + +## 1. Classification Tasks + +### Basic Single-Label Classification + +```jsonl +{"input": "This movie is absolutely fantastic! 
I loved every minute of it.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}} +{"input": "The service at this restaurant was terrible and the food was cold.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["negative"]}]}} +{"input": "The weather today is okay, nothing special.", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["neutral"]}]}} +``` + +### Multi-label Classification + +```jsonl +{"input": "This smartphone has an amazing camera but the battery life is poor.", "output": {"classifications": [{"task": "product_aspects", "labels": ["camera", "battery", "screen", "performance", "design"], "true_label": ["camera", "battery"], "multi_label": true}]}} +{"input": "Great performance and beautiful design!", "output": {"classifications": [{"task": "product_aspects", "labels": ["camera", "battery", "screen", "performance", "design"], "true_label": ["performance", "design"], "multi_label": true}]}} +``` + +### Classification with Label Descriptions + +```jsonl +{"input": "Breaking: New AI model achieves human-level performance on reasoning tasks.", "output": {"classifications": [{"task": "news_category", "labels": ["technology", "politics", "sports", "entertainment"], "true_label": ["technology"], "label_descriptions": {"technology": "Articles about computers, AI, software, and tech innovations", "politics": "Government, elections, and political news", "sports": "Athletic events, teams, and competitions", "entertainment": "Movies, music, celebrities, and entertainment news"}}]}} +``` + +### Classification with Custom Prompts + +```jsonl +{"input": "The patient shows signs of improvement after treatment.", "output": {"classifications": [{"task": "medical_assessment", "labels": ["improving", "stable", "declining", "critical"], "true_label": ["improving"], 
"prompt": "Assess the patient's medical condition based on the clinical notes."}]}} +``` + +### Classification with Few-Shot Examples + +Few-shot examples are provided as a list of `[input, output]` pairs. Each example is a list/tuple with exactly 2 elements: + +```jsonl +{"input": "This service exceeded all my expectations!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"], "examples": [["Great product, highly recommend!", "positive"], ["Terrible experience, very disappointed.", "negative"], ["It's okay, nothing special.", "neutral"]]}]}} +``` + +**Format**: `"examples": [[input_text, output_label], [input_text, output_label], ...]` + +Each example pair must have exactly 2 elements: the input text and the corresponding label. + +### Classification with Both Examples and Descriptions + +```jsonl +{"input": "The algorithm demonstrates linear time complexity.", "output": {"classifications": [{"task": "complexity", "labels": ["constant", "linear", "quadratic", "exponential"], "true_label": ["linear"], "examples": [["O(1) lookup time", "constant"], ["O(n) iteration", "linear"]], "label_descriptions": {"constant": "O(1) - fixed time regardless of input size", "linear": "O(n) - time scales linearly with input", "quadratic": "O(nΒ²) - nested iterations", "exponential": "O(2ⁿ) - recursive branching"}}]}} +``` + +### Multiple Classification Tasks + +```jsonl +{"input": "Exciting new smartphone with innovative features!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}, {"task": "category", "labels": ["technology", "sports", "politics", "entertainment"], "true_label": ["technology"]}]}} +``` + +### true_label: String vs List Format + +Both formats are supported - use list for consistency or string for brevity: + +```jsonl +{"input": "Sample text A", "output": {"classifications": [{"task": "label", "labels": ["a", "b"], 
"true_label": ["a"]}]}} +{"input": "Sample text B", "output": {"classifications": [{"task": "label", "labels": ["a", "b"], "true_label": "b"}]}} +{"input": "This is great!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": "positive"}]}} +``` + +**Note**: +- String format (`"true_label": "positive"`) and list format (`"true_label": ["positive"]`) are both valid for single-label classification +- Internally, string values are automatically converted to lists (`["positive"]`) +- For multi-label classification, always use list format: `"true_label": ["label1", "label2"]` + +--- + +## 2. Named Entity Recognition (NER) + +### Basic NER + +```jsonl +{"input": "John Smith works at OpenAI in San Francisco and will visit London next month.", "output": {"entities": {"person": ["John Smith"], "organization": ["OpenAI"], "location": ["San Francisco", "London"]}}} +{"input": "Apple Inc. CEO Tim Cook announced the iPhone 15 release date.", "output": {"entities": {"person": ["Tim Cook"], "organization": ["Apple Inc."], "product": ["iPhone 15"]}}} +{"input": "The meeting on January 15, 2024 will be held at Microsoft headquarters.", "output": {"entities": {"date": ["January 15, 2024"], "organization": ["Microsoft"]}}} +``` + +### NER with Entity Descriptions + +```jsonl +{"input": "Dr. Sarah Johnson prescribed Metformin 500mg twice daily for diabetes treatment.", "output": {"entities": {"person": ["Dr. 
Sarah Johnson"], "medication": ["Metformin"], "dosage": ["500mg"], "condition": ["diabetes"]}, "entity_descriptions": {"person": "Names of people mentioned in the text", "medication": "Names of drugs or pharmaceutical products", "dosage": "Specific amounts or dosages of medications", "condition": "Medical conditions or diseases"}}} +``` + +### NER with Multiple Instances of Same Entity Type + +```jsonl +{"input": "Alice, Bob, and Charlie attended the meeting with David.", "output": {"entities": {"person": ["Alice", "Bob", "Charlie", "David"]}}} +``` + +### NER with Empty Entity Types + +```jsonl +{"input": "The conference will be held next week.", "output": {"entities": {"person": [], "organization": [], "location": []}}} +``` + +### Partial NER (Some Entity Types Present) + +```jsonl +{"input": "Microsoft announced new features.", "output": {"entities": {"organization": ["Microsoft"], "person": []}}} +``` + +--- + +## 3. JSON Structure Extraction + +### Basic Structure with String Fields + +```jsonl +{"input": "Contact John Doe at john.doe@email.com or call (555) 123-4567.", "output": {"json_structures": [{"contact": {"name": "John Doe", "email": "john.doe@email.com", "phone": "(555) 123-4567"}}]}} +``` + +### Structure with List Fields + +```jsonl +{"input": "Product features include: wireless charging, water resistance, and face recognition.", "output": {"json_structures": [{"product": {"features": ["wireless charging", "water resistance", "face recognition"]}}]}} +``` + +### Structure with Mixed String and List Fields + +```jsonl +{"input": "iPhone 15 costs $999 and comes in blue, black, and white colors.", "output": {"json_structures": [{"product": {"name": "iPhone 15", "price": "$999", "colors": ["blue", "black", "white"]}}]}} +``` + +### Multiple Instances of Same Structure Type + +When the **same structure type** (parent name) appears multiple times in the text, each instance is a **separate dictionary** in the `json_structures` list: + +```jsonl +{"input": 
"We have two hotels available: Hotel Paradise with 4 stars, pool, and wifi for $150/night, and Budget Inn with 2 stars and parking for $80/night.", "output": {"json_structures": [{"hotel": {"name": "Hotel Paradise", "stars": "4", "amenities": ["pool", "wifi"], "price": "$150/night"}}, {"hotel": {"name": "Budget Inn", "stars": "2", "amenities": ["parking"], "price": "$80/night"}}]}} +``` + +**Note**: Both instances use the same parent key "hotel" but are separate objects in the list. This is how you represent multiple occurrences of the same structure type. + +Another example with three products: + +```jsonl +{"input": "Available products: iPhone 15 for $999, MacBook Pro for $1999, and AirPods for $199.", "output": {"json_structures": [{"product": {"name": "iPhone 15", "price": "$999"}}, {"product": {"name": "MacBook Pro", "price": "$1999"}}, {"product": {"name": "AirPods", "price": "$199"}}]}} +``` + +### Structure with Classification Fields (Choices) + +```jsonl +{"input": "Book a single room at Grand Hotel for 2 nights with breakfast included.", "output": {"json_structures": [{"booking": {"hotel": "Grand Hotel", "room_type": {"value": "single", "choices": ["single", "double", "suite"]}, "nights": "2", "meal_plan": {"value": "breakfast", "choices": ["none", "breakfast", "half-board", "full-board"]}}}]}} +``` + +### Structure with Multiple Choice Fields + +```jsonl +{"input": "Order a large pepperoni pizza for delivery, extra cheese.", "output": {"json_structures": [{"order": {"size": {"value": "large", "choices": ["small", "medium", "large", "xlarge"]}, "type": {"value": "pepperoni", "choices": ["cheese", "pepperoni", "veggie", "supreme"]}, "method": {"value": "delivery", "choices": ["pickup", "delivery", "dine-in"]}, "extras": ["extra cheese"]}}]}} +``` + +### Structure with Field Descriptions + +```jsonl +{"input": "Patient: Mary Wilson, Age: 45, diagnosed with hypertension, prescribed Lisinopril 10mg daily.", "output": {"json_structures": [{"medical_record": 
{"patient_name": "Mary Wilson", "age": "45", "diagnosis": "hypertension", "medication": "Lisinopril", "dosage": "10mg daily"}}], "json_descriptions": {"medical_record": {"patient_name": "Full name of the patient", "age": "Patient's age in years", "diagnosis": "Medical condition diagnosed", "medication": "Prescribed medication name", "dosage": "Medication dosage and frequency"}}}} +``` + +### Structure with Null/Empty Field Values + +```jsonl +{"input": "Product name is Widget X. Price not available.", "output": {"json_structures": [{"product": {"name": "Widget X", "price": "", "description": ""}}]}} +``` + +### Structure with Some Fields Missing + +```jsonl +{"input": "Contact Sarah at sarah@example.com", "output": {"json_structures": [{"contact": {"name": "Sarah", "email": "sarah@example.com", "phone": ""}}]}} +``` + +### Multiple Different Structure Types + +```jsonl +{"input": "John Doe works at TechCorp. Product ABC costs $50 with free shipping.", "output": {"json_structures": [{"employee": {"name": "John Doe", "company": "TechCorp"}}, {"product": {"name": "ABC", "price": "$50", "shipping": "free"}}]}} +``` + +### Structure with Only List Fields + +```jsonl +{"input": "Available colors: red, blue, green. Sizes: S, M, L, XL.", "output": {"json_structures": [{"options": {"colors": ["red", "blue", "green"], "sizes": ["S", "M", "L", "XL"]}}]}} +``` + +--- + +## 4. Relation Extraction + +Relations use flexible field structures. While "head" and "tail" are common, you can use ANY field names. + +**⚠️ Important**: The first occurrence of each relation type defines the field structure for ALL instances of that type. 
+ +### Basic Relation (Head and Tail) + +```jsonl +{"input": "Alice manages the engineering team.", "output": {"relations": [{"manages": {"head": "Alice", "tail": "engineering team"}}]}} +{"input": "John works for Microsoft.", "output": {"relations": [{"works_for": {"head": "John", "tail": "Microsoft"}}]}} +``` + +### Multiple Instances - Same Field Structure + +All instances of the same relation type MUST have the same fields (determined by first occurrence): + +```jsonl +{"input": "Alice works for Google. Bob works for Microsoft. Charlie works for Amazon.", "output": {"relations": [{"works_for": {"head": "Alice", "tail": "Google"}}, {"works_for": {"head": "Bob", "tail": "Microsoft"}}, {"works_for": {"head": "Charlie", "tail": "Amazon"}}]}} +``` + +**Note**: All three "works_for" instances use the same fields (head, tail) as defined by the first occurrence. + +### Multiple Different Relation Types + +Different relation types can have different field structures: + +```jsonl +{"input": "John works for Apple Inc. and lives in San Francisco. Apple Inc. is located in Cupertino.", "output": {"relations": [{"works_for": {"head": "John", "tail": "Apple Inc."}}, {"lives_in": {"head": "John", "tail": "San Francisco"}}, {"located_in": {"head": "Apple Inc.", "tail": "Cupertino"}}]}} +``` + +**Note**: Each relation type ("works_for", "lives_in", "located_in") can independently define its own field structure. + +### Custom Field Names (Beyond Head/Tail) + +You can use custom field names - the first occurrence defines what fields to use: + +```jsonl +{"input": "Alice sent $100 to Bob. Charlie sent $50 to David.", "output": {"relations": [{"transaction": {"sender": "Alice", "recipient": "Bob", "amount": "$100"}}, {"transaction": {"sender": "Charlie", "recipient": "David", "amount": "$50"}}]}} +``` + +**Note**: First "transaction" uses sender/recipient/amount, so all "transaction" instances must use these same fields. 
+
+### Relations with Additional Fields
+
+```jsonl
+{"input": "John Smith is the CEO of TechCorp which is headquartered in Silicon Valley.", "output": {"relations": [{"employment": {"head": "John Smith", "tail": "TechCorp", "role": "CEO"}}, {"located_in": {"head": "TechCorp", "tail": "Silicon Valley"}}]}}
+```
+
+### Relations Combined with Entities
+
+```jsonl
+{"input": "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne.", "output": {"entities": {"person": ["Elon Musk"], "organization": ["SpaceX"], "location": ["Hawthorne"], "date": ["2002"]}, "relations": [{"founded": {"head": "Elon Musk", "tail": "SpaceX"}}, {"located_in": {"head": "SpaceX", "tail": "Hawthorne"}}]}}
+```
+
+### Empty Relations (Negative Example)
+
+```jsonl
+{"input": "The weather is nice today.", "output": {"relations": []}}
+```
+
+### Bidirectional Relations
+
+```jsonl
+{"input": "Alice and Bob are colleagues.", "output": {"relations": [{"colleague_of": {"head": "Alice", "tail": "Bob"}}, {"colleague_of": {"head": "Bob", "tail": "Alice"}}]}}
+```
+
+### Field Consistency: Relations vs JSON Structures
+
+**Key Difference**:
+
+- **Relations**: First occurrence defines field structure for ALL instances of that relation type
+  - All "works_for" relations must have same fields
+  - Enforced consistency per relation type
+
+- **JSON Structures**: Fields can vary between instances of the same parent type
+  - Uses union of all fields across instances
+  - More flexible - instances can have different subsets of fields
+
+**Example - Relations (Strict Consistency)**:
+```jsonl
+{"input": "Alice works for Google. Bob works for Microsoft.", "output": {"relations": [{"works_for": {"head": "Alice", "tail": "Google"}}, {"works_for": {"head": "Bob", "tail": "Microsoft"}}]}}
+```
+✓ Valid: Both "works_for" have same fields (head, tail)
+
+**Example - JSON Structures (Flexible Fields)**:
+```jsonl
+{"input": "Product A costs $10. 
Product B costs $20 and weighs 5kg.", "output": {"json_structures": [{"product": {"name": "A", "price": "$10"}}, {"product": {"name": "B", "price": "$20", "weight": "5kg"}}]}}
+```
+✓ Valid: Second instance has extra "weight" field - this is allowed for json_structures
+
+---
+
+## 5. Combined Multi-Task Examples
+
+### Entities + Classifications
+
+```jsonl
+{"input": "Apple Inc. announced record profits. This is great news for investors.", "output": {"entities": {"organization": ["Apple Inc."]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}}
+```
+
+### Entities + JSON Structures
+
+```jsonl
+{"input": "Contact John Doe at john@example.com. He works at TechCorp.", "output": {"entities": {"person": ["John Doe"], "organization": ["TechCorp"]}, "json_structures": [{"contact": {"name": "John Doe", "email": "john@example.com", "company": "TechCorp"}}]}}
+```
+
+### Entities + Relations
+
+```jsonl
+{"input": "Elon Musk founded SpaceX in 2002. SpaceX is located in Hawthorne.", "output": {"entities": {"person": ["Elon Musk"], "organization": ["SpaceX"], "location": ["Hawthorne"], "date": ["2002"]}, "relations": [{"founded": {"head": "Elon Musk", "tail": "SpaceX", "year": "2002"}}, {"located_in": {"head": "SpaceX", "tail": "Hawthorne"}}]}}
+```
+
+### Classifications + JSON Structures
+
+```jsonl
+{"input": "Premium subscription for $99/month includes unlimited access. Great value!", "output": {"classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}], "json_structures": [{"subscription": {"tier": "Premium", "price": "$99/month", "features": ["unlimited access"]}}]}}
+```
+
+### Entities + Classifications + JSON Structures
+
+```jsonl
+{"input": "Apple CEO Tim Cook unveiled iPhone 15 for $999. 
Analysts are optimistic.", "output": {"entities": {"person": ["Tim Cook"], "organization": ["Apple"], "product": ["iPhone 15"]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}], "json_structures": [{"product_announcement": {"company": "Apple", "product": "iPhone 15", "price": "$999", "presenter": "Tim Cook"}}]}} +``` + +### Entities + Relations + Classifications + +```jsonl +{"input": "Sarah founded TechStart in 2020. The company is doing exceptionally well.", "output": {"entities": {"person": ["Sarah"], "organization": ["TechStart"], "date": ["2020"]}, "relations": [{"founded": {"head": "Sarah", "tail": "TechStart", "year": "2020"}}], "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}]}} +``` + +### All Four Tasks Combined + +```jsonl +{"input": "Breaking: Apple announces new iPhone 15 with improved camera. Analysts are optimistic about sales projections.", "output": {"entities": {"company": ["Apple"], "product": ["iPhone 15"]}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative", "neutral"], "true_label": ["positive"]}, {"task": "category", "labels": ["technology", "business", "sports", "entertainment"], "true_label": ["technology"]}], "json_structures": [{"news_article": {"company": "Apple", "product": "iPhone 15", "feature": "improved camera", "analyst_view": "optimistic"}}], "relations": [{"product_of": {"head": "iPhone 15", "tail": "Apple"}}]}} +``` + +### Multi-Task with Descriptions + +```jsonl +{"input": "Dr. Johnson prescribed medication X for condition Y. Patient shows improvement.", "output": {"entities": {"person": ["Dr. 
Johnson"], "medication": ["medication X"], "condition": ["condition Y"]}, "entity_descriptions": {"person": "Healthcare provider names", "medication": "Prescribed drugs", "condition": "Medical conditions"}, "classifications": [{"task": "patient_status", "labels": ["improving", "stable", "declining"], "true_label": ["improving"], "label_descriptions": {"improving": "Patient condition getting better", "stable": "No change in condition", "declining": "Patient condition worsening"}}], "json_structures": [{"prescription": {"doctor": "Dr. Johnson", "medication": "medication X", "condition": "condition Y"}}], "json_descriptions": {"prescription": {"doctor": "Prescribing physician", "medication": "Prescribed drug name", "condition": "Diagnosed condition"}}}} +``` + +### Partial Multi-Task (Some Tasks Empty) + +**Note**: While you can include empty dictionaries/lists for some tasks, at least one task must have content. + +```jsonl +{"input": "The weather forecast predicts rain tomorrow.", "output": {"entities": {}, "classifications": [{"task": "weather", "labels": ["sunny", "rainy", "cloudy", "snowy"], "true_label": ["rainy"]}], "json_structures": []}} +``` + +This is valid because it has a classification task. However, if all tasks were empty, it would fail validation. + +--- + +## 6. Format Edge Cases + +### Completely Empty Output + +**⚠️ Note**: Examples must have at least one task (entities, classifications, structures, or relations). Completely empty outputs are not valid training examples. + +```jsonl +{"input": "Random text with no specific information.", "output": {"entities": {}, "classifications": [], "json_structures": [], "relations": []}} +``` + +This format will fail validation. Each example must contain at least one annotation. + +### Empty Entities Dictionary + +**⚠️ Note**: While an empty entities dictionary is syntactically valid, examples must have at least one task. 
If you only have empty entities, add at least one other task (classification, structure, or relation).
+
+```jsonl
+{"input": "The weather is nice today.", "output": {"entities": {}, "classifications": [{"task": "sentiment", "labels": ["positive", "negative"], "true_label": ["positive"]}]}}
+```
+
+### Empty Classifications List
+
+**⚠️ Note**: While an empty classifications list is syntactically valid, examples must have at least one task. If you only have empty classifications, add at least one other task.
+
+```jsonl
+{"input": "Some generic text.", "output": {"classifications": [], "entities": {"location": ["text"]}}}
+```
+
+### Very Long Label Lists
+
+```jsonl
+{"input": "Sample text for many labels.", "output": {"classifications": [{"task": "topic", "labels": ["label1", "label2", "label3", "label4", "label5", "label6", "label7", "label8", "label9", "label10", "label11", "label12", "label13", "label14", "label15", "label16", "label17", "label18", "label19", "label20"], "true_label": ["label5"]}]}}
+```
+
+### Very Short Text
+
+```jsonl
+{"input": "Yes.", "output": {"classifications": [{"task": "response", "labels": ["yes", "no", "maybe"], "true_label": ["yes"]}]}}
+{"input": "OK", "output": {"entities": {}}}
+```
+
+### Special Characters in Labels
+
+```jsonl
+{"input": "The C++ programming language.", "output": {"entities": {"programming_language": ["C++"]}}}
+{"input": "Use the @ symbol for mentions.", "output": {"entities": {"symbol": ["@"]}}}
+```
+
+### Special Characters in Values
+
+```jsonl
+{"input": "Price is $1,299.99 (including tax).", "output": {"json_structures": [{"pricing": {"amount": "$1,299.99", "note": "(including tax)"}}]}}
+```
+
+### Unicode and Non-ASCII Characters
+
+```jsonl
+{"input": "Café Münchën serves crème brûlée.", "output": {"entities": {"location": ["Café Münchën"], "food": ["crème brûlée"]}}}
+{"input": "东京 Tokyo is the capital.", "output": {"entities": {"location": ["东京", "Tokyo"]}}}
+```
+
+### Quotes and 
Escaping + +```jsonl +{"input": "He said \"hello\" to me.", "output": {"entities": {"quote": ["\"hello\""]}}} +``` + +### Newlines in Text + +```jsonl +{"input": "First line.\nSecond line.", "output": {"entities": {"text": ["First line", "Second line"]}}} +``` + +### Numbers as Strings vs Entity Names + +```jsonl +{"input": "Room 123 on floor 4.", "output": {"json_structures": [{"location": {"room": "123", "floor": "4"}}]}} +``` + +### Boolean-like Values + +```jsonl +{"input": "Status is active, notifications enabled.", "output": {"json_structures": [{"settings": {"status": "active", "notifications": "enabled"}}]}} +``` + +### Empty String Values + +```jsonl +{"input": "Name: John, Age: unknown", "output": {"json_structures": [{"person": {"name": "John", "age": ""}}]}} +``` + +### Multiple Empty Lines in JSONL + +```jsonl +{"input": "First example.", "output": {"entities": {"type": ["example"]}}} +{"input": "Second example.", "output": {"entities": {"type": ["example"]}}} +``` + +--- + +## Schema Component Reference + +### entities +- **Type**: `dict[str, list[str]]` +- **Format**: `{"entity_type": ["mention1", "mention2", ...]}` +- **Example**: `{"person": ["John", "Alice"], "location": ["NYC"]}` + +### entity_descriptions +- **Type**: `dict[str, str]` +- **Format**: `{"entity_type": "description text"}` +- **Example**: `{"person": "Names of people", "location": "Geographic places"}` + +### classifications +- **Type**: `list[dict]` +- **Required fields**: `task`, `labels`, `true_label` +- **Optional fields**: `multi_label`, `prompt`, `examples`, `label_descriptions` +- **Example**: `[{"task": "sentiment", "labels": ["pos", "neg"], "true_label": ["pos"]}]` + +### json_structures +- **Type**: `list[dict]` +- **Single instance**: `[{"parent_name": {"field1": "value1", "field2": ["list", "values"]}}]` +- **Multiple instances (same parent)**: `[{"parent": {...}}, {"parent": {...}}]` - Same parent key, separate dicts +- **Multiple types**: `[{"parent1": {...}}, 
{"parent2": {...}}]` - Different parent keys +- **Choice format**: `{"field": {"value": "selected", "choices": ["opt1", "opt2"]}}` +- **Example**: `[{"product": {"name": "Item", "price": "$10"}}, {"product": {"name": "Item2", "price": "$20"}}]` + +### json_descriptions +- **Type**: `dict[str, dict[str, str]]` +- **Format**: `{"parent": {"field": "description"}}` +- **Example**: `{"product": {"name": "Product name", "price": "Cost in USD"}}` + +### relations +- **Type**: `list[dict]` +- **Standard format**: `[{"relation_name": {"head": "entity1", "tail": "entity2"}}]` +- **With custom fields**: `[{"relation_name": {"sender": "A", "recipient": "B", "amount": "$100"}}]` +- **Example**: `[{"works_for": {"head": "John", "tail": "Company"}}, {"founded": {"head": "Alice", "tail": "StartupX"}}]` +- **⚠️ Field constraint**: First occurrence of each relation type defines field structure for ALL instances of that type +- **Note**: While "head" and "tail" are common, you can use ANY field names - just keep them consistent per relation type + +--- + +## Tips for Dataset Creation + +1. **Use diverse examples** to improve model generalization +2. **Include edge cases** - but remember each example must have at least one task +3. **Provide descriptions** when possible to improve accuracy +4. **Balance your classes** in classification tasks +5. **Use realistic text** that matches your target domain +6. **Include multiple instances** for JSON structures when applicable +7. **For negative examples**, include at least one task (e.g., empty entities but a classification, or empty classifications but entities) +8. **Mix task types** to train multi-task capabilities +9. **Use consistent formatting** for similar examples +10. **Include special characters** to ensure robust handling +11. **Validate your dataset** using `TrainingDataset.validate(strict=True)` to catch annotation errors early +12. 
**Check relation consistency** using `validate_relation_consistency()` to ensure all relation types have consistent field structures + +## Validation Checklist + +Make sure your JSONL file is valid by checking: +- [ ] Each line is valid JSON +- [ ] Required fields (`input`/`output` or `text`/`schema`) are present +- [ ] **At least one task is present** (entities, classifications, structures, or relations) +- [ ] Schema structure matches the expected format +- [ ] Entity spans exist in the input text (entities can be found in the input) - checked in strict validation mode +- [ ] Classification labels are from the defined label set +- [ ] `true_label` is a list or string (string format is converted to list internally) +- [ ] For multi-label classification, `multi_label` is set to `true` when multiple labels are provided +- [ ] JSON structure fields match between instances of the same parent (flexible - union of fields is used) +- [ ] **Relation field consistency**: All instances of the same relation type use the same field names (determined by first occurrence) +- [ ] No trailing commas in JSON objects +- [ ] Special characters are properly escaped +- [ ] File encoding is UTF-8 + +### Validation Modes + +The implementation supports two validation modes: + +- **Standard validation**: Checks format correctness, required fields, label consistency +- **Strict validation**: Additionally checks that entity mentions and relation values exist in the input text (case-insensitive substring matching) + +Use strict validation during dataset creation to catch annotation errors early. diff --git a/packages/GLiNER2/tutorial/9-training.md b/packages/GLiNER2/tutorial/9-training.md new file mode 100644 index 0000000..6543e47 --- /dev/null +++ b/packages/GLiNER2/tutorial/9-training.md @@ -0,0 +1,1296 @@ +# GLiNER2 Training Tutorial + +Complete guide to training GLiNER2 models for entity extraction, classification, structured data extraction, and relation extraction. 
+ +## Table of Contents + +1. [Quick Start](#quick-start) +2. [End-to-End Training Examples](#end-to-end-training-examples) +3. [Data Preparation](#data-preparation) +4. [Training Configuration](#training-configuration) +5. [LoRA Training](#lora-training) +6. [Advanced Topics](#advanced-topics) +7. [Troubleshooting](#troubleshooting) + +--- + +## Quick Start + +### Minimal Example + +```python +from gliner2 import GLiNER2 +from gliner2.training.data import InputExample +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# 1. Create training examples +examples = [ + InputExample( + text="John works at Google in California.", + entities={"person": ["John"], "company": ["Google"], "location": ["California"]} + ), + InputExample( + text="Apple released iPhone 15.", + entities={"company": ["Apple"], "product": ["iPhone 15"]} + ), +] + +# 2. Initialize model and config +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig( + output_dir="./output", + num_epochs=10, + batch_size=8, + encoder_lr=1e-5, + task_lr=5e-4 +) + +# 3. 
Train +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=examples) +``` + +### Quick Start with JSONL + +```python +# Create train.jsonl file with format: +# {"input": "text here", "output": {"entities": {"type": ["mention1", "mention2"]}}} + +from gliner2 import GLiNER2 +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig(output_dir="./output", num_epochs=10) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data="train.jsonl") +``` + +--- + +## End-to-End Training Examples + +### Example 1: Complete NER Training Pipeline + +```python +from gliner2 import GLiNER2 +from gliner2.training.data import InputExample, TrainingDataset +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# Step 1: Prepare training data +train_examples = [ + InputExample( + text="Tim Cook is the CEO of Apple Inc., based in Cupertino, California.", + entities={ + "person": ["Tim Cook"], + "company": ["Apple Inc."], + "location": ["Cupertino", "California"] + }, + entity_descriptions={ + "person": "Full name of a person", + "company": "Business organization name", + "location": "Geographic location or place" + } + ), + InputExample( + text="OpenAI released GPT-4 in March 2023. The model was developed in San Francisco.", + entities={ + "company": ["OpenAI"], + "model": ["GPT-4"], + "date": ["March 2023"], + "location": ["San Francisco"] + }, + entity_descriptions={ + "model": "Machine learning model or AI system", + "date": "Date or time reference" + } + ), + # Add more examples... 
+] + +# Step 2: Create and validate dataset +train_dataset = TrainingDataset(train_examples) +train_dataset.validate(strict=True, raise_on_error=True) +train_dataset.print_stats() + +# Step 3: Split into train/validation +train_data, val_data, _ = train_dataset.split( + train_ratio=0.8, + val_ratio=0.2, + test_ratio=0.0, + shuffle=True, + seed=42 +) + +# Step 4: Save datasets +train_data.save("train.jsonl") +val_data.save("val.jsonl") + +# Step 5: Configure training +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig( + output_dir="./ner_model", + experiment_name="ner_training", + num_epochs=15, + batch_size=16, + encoder_lr=1e-5, + task_lr=5e-4, + warmup_ratio=0.1, + scheduler_type="cosine", + fp16=True, + eval_strategy="epoch", # Evaluates and saves at end of each epoch + save_best=True, + early_stopping=True, + early_stopping_patience=3, + report_to_wandb=True, + wandb_project="ner_training" +) + +# Step 6: Train +trainer = GLiNER2Trainer(model, config) +results = trainer.train( + train_data=train_data, + eval_data=val_data +) + +print(f"Training completed!") +print(f"Best validation loss: {results['best_metric']:.4f}") +print(f"Total steps: {results['total_steps']}") +print(f"Training time: {results['total_time_seconds']/60:.1f} minutes") + +# Step 7: Load best model for inference +best_model = GLiNER2.from_pretrained("./ner_model/best") +``` + +### Example 2: Multi-Task Training (NER + Classification + Relations) + +```python +from gliner2 import GLiNER2 +from gliner2.training.data import InputExample, Classification, Relation +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# Create multi-task examples +examples = [ + InputExample( + text="John Smith works at Google in California. 
The company is thriving and expanding rapidly.", + entities={ + "person": ["John Smith"], + "company": ["Google"], + "location": ["California"] + }, + classifications=[ + Classification( + task="sentiment", + labels=["positive", "negative", "neutral"], + true_label="positive" + ) + ], + relations=[ + Relation("works_at", head="John Smith", tail="Google"), + Relation("located_in", head="Google", tail="California") + ] + ), + # More examples... +] + +# Train multi-task model +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig( + output_dir="./multitask_model", + num_epochs=20, + batch_size=16, + encoder_lr=1e-5, + task_lr=5e-4 +) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=examples) +``` + +### Example 3: Domain-Specific Fine-tuning (Medical NER) + +```python +from gliner2 import GLiNER2 +from gliner2.training.data import InputExample, TrainingDataset +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# Medical domain examples +medical_examples = [ + InputExample( + text="Patient presented with hypertension and type 2 diabetes mellitus.", + entities={ + "condition": ["hypertension", "type 2 diabetes mellitus"] + }, + entity_descriptions={ + "condition": "Medical condition, disease, or symptom" + } + ), + InputExample( + text="Prescribed metformin 500mg twice daily. Patient to follow up in 2 weeks.", + entities={ + "medication": ["metformin"], + "dosage": ["500mg"], + "frequency": ["twice daily"], + "duration": ["2 weeks"] + }, + entity_descriptions={ + "medication": "Prescribed drug or medication name", + "dosage": "Amount or strength of medication", + "frequency": "How often medication is taken", + "duration": "Time period for treatment" + } + ), + # More medical examples... 
+] + +# Fine-tune on medical domain +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") +config = TrainingConfig( + output_dir="./medical_ner", + num_epochs=20, + batch_size=16, + encoder_lr=5e-6, # Lower LR for fine-tuning + task_lr=1e-4, + warmup_ratio=0.05 +) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=medical_examples) +``` + +--- + +## Data Preparation + +### Supported Data Formats + +GLiNER2 supports multiple input formats: + +1. **JSONL Files** (recommended for large datasets) +2. **InputExample List** (recommended for programmatic creation) +3. **TrainingDataset Object** +4. **Raw Dict List** + +```python +# Format 1: JSONL file(s) +trainer.train(train_data="train.jsonl") +trainer.train(train_data=["train1.jsonl", "train2.jsonl"]) + +# Format 2: InputExample list +examples = [InputExample(...), ...] +trainer.train(train_data=examples) + +# Format 3: TrainingDataset +dataset = TrainingDataset.load("train.jsonl") +trainer.train(train_data=dataset) + +# Format 4: Raw dicts +raw_data = [{"input": "...", "output": {...}}, ...] +trainer.train(train_data=raw_data) +``` + +### Creating Training Examples + +#### Entity Extraction + +```python +from gliner2.training.data import InputExample + +# Simple entity extraction +example = InputExample( + text="John Smith works at Google in San Francisco.", + entities={ + "person": ["John Smith"], + "company": ["Google"], + "location": ["San Francisco"] + } +) + +# With entity descriptions (improves model understanding) +example = InputExample( + text="BERT was developed by Google AI.", + entities={ + "model": ["BERT"], + "organization": ["Google AI"] + }, + entity_descriptions={ + "model": "Machine learning model or architecture", + "organization": "Company, research lab, or institution" + } +) +``` + +#### Classification + +```python +from gliner2.training.data import InputExample, Classification + +# Single-label classification +example = InputExample( + text="This movie is amazing! 
Best film of the year.", + classifications=[ + Classification( + task="sentiment", + labels=["positive", "negative", "neutral"], + true_label="positive" + ) + ] +) + +# Multi-label classification +example = InputExample( + text="Python tutorial covers machine learning and web development.", + classifications=[ + Classification( + task="topic", + labels=["programming", "machine_learning", "web_dev", "data_science"], + true_label=["programming", "machine_learning", "web_dev"], + multi_label=True + ) + ] +) +``` + +#### Structured Data Extraction + +```python +from gliner2.training.data import InputExample, Structure, ChoiceField + +# Simple structure +example = InputExample( + text="iPhone 15 Pro costs $999 and comes in titanium color.", + structures=[ + Structure( + "product", + name="iPhone 15 Pro", + price="$999", + color="titanium" + ) + ] +) + +# With choice fields +example = InputExample( + text="Order #12345 for laptop shipped on 2024-01-15.", + structures=[ + Structure( + "order", + order_id="12345", + product="laptop", + date="2024-01-15", + status=ChoiceField( + value="shipped", + choices=["pending", "processing", "shipped", "delivered"] + ) + ) + ] +) +``` + +#### Relation Extraction + +```python +from gliner2.training.data import InputExample, Relation + +# Binary relations +example = InputExample( + text="Elon Musk founded SpaceX in 2002.", + relations=[ + Relation("founded", head="Elon Musk", tail="SpaceX"), + Relation("founded_in", head="SpaceX", tail="2002") + ] +) + +# Custom relation fields +example = InputExample( + text="Exercise improves mental health.", + relations=[ + Relation( + "causal_relation", + cause="exercise", + effect="mental health", + direction="positive" + ) + ] +) +``` + +### Data Validation + +```python +from gliner2.training.data import TrainingDataset + +# Load and validate dataset +dataset = TrainingDataset.load("train.jsonl") + +# Strict validation (checks entity spans exist in text) +try: + dataset.validate(strict=True, 
raise_on_error=True) +except ValidationError as e: + print(f"Validation failed: {e}") + +# Get validation report +report = dataset.validate(raise_on_error=False) +print(f"Valid: {report['valid']}, Invalid: {report['invalid']}") + +# Print statistics +dataset.print_stats() +``` + +### Data Splitting and Management + +```python +from gliner2.training.data import TrainingDataset + +# Load full dataset +dataset = TrainingDataset.load("full_data.jsonl") + +# Split into train/val/test +train_data, val_data, test_data = dataset.split( + train_ratio=0.8, + val_ratio=0.1, + test_ratio=0.1, + shuffle=True, + seed=42 +) + +# Save splits +train_data.save("train.jsonl") +val_data.save("val.jsonl") +test_data.save("test.jsonl") + +# Filter and sample +entity_only = dataset.filter(lambda ex: len(ex.entities) > 0) +small_sample = dataset.sample(n=100, seed=42) + +# Combine multiple datasets +dataset1 = TrainingDataset.load("dataset1.jsonl") +dataset2 = TrainingDataset.load("dataset2.jsonl") +combined = TrainingDataset() +combined.add_many(dataset1.examples) +combined.add_many(dataset2.examples) +``` + +--- + +## Training Configuration + +### Checkpoint Saving Behavior + +**Important**: Checkpoints are automatically saved when evaluation runs. The `eval_strategy` parameter controls both when to evaluate and when to save checkpoints: + +- `eval_strategy="steps"` (default): Evaluate and save every `eval_steps` steps +- `eval_strategy="epoch"`: Evaluate and save at the end of each epoch +- `eval_strategy="no"`: No evaluation or checkpoint saving (except final checkpoint) + +The best checkpoint (based on `metric_for_best`) is always saved separately when `save_best=True`. 
+ +### Basic Configuration + +```python +from gliner2.training.trainer import TrainingConfig + +config = TrainingConfig( + # Output + output_dir="./output", + experiment_name="my_experiment", + + # Training + num_epochs=10, + batch_size=32, + gradient_accumulation_steps=1, + + # Learning rates + encoder_lr=1e-5, + task_lr=5e-4, + + # Optimization + weight_decay=0.01, + max_grad_norm=1.0, + scheduler_type="linear", + warmup_ratio=0.1, + + # Mixed precision + fp16=True, + + # Checkpointing & Evaluation (saves when evaluating) + eval_strategy="epoch", # "epoch", "steps", or "no" + eval_steps=500, # Used when eval_strategy="steps" + save_best=True, + + # Logging + logging_steps=50, + report_to_wandb=False, + wandb_project=None +) +``` + +### Common Configurations + +**Fast Prototyping:** +```python +config = TrainingConfig( + output_dir="./quick_test", + num_epochs=3, + batch_size=16, + encoder_lr=1e-5, + task_lr=5e-4, + max_train_samples=100, + eval_strategy="no" +) +``` + +**Production Training:** +```python +config = TrainingConfig( + output_dir="./production_model", + num_epochs=20, + batch_size=32, + gradient_accumulation_steps=2, + encoder_lr=5e-6, + task_lr=1e-4, + weight_decay=0.01, + warmup_ratio=0.1, + scheduler_type="cosine", + fp16=True, + eval_strategy="steps", # Evaluates and saves every N steps + eval_steps=500, + save_total_limit=5, + save_best=True, + early_stopping=True, + early_stopping_patience=5, + report_to_wandb=True, + wandb_project="gliner2-production" +) +``` + +**Memory-Optimized:** +```python +config = TrainingConfig( + output_dir="./large_model", + num_epochs=10, + batch_size=8, + gradient_accumulation_steps=8, + gradient_checkpointing=True, + fp16=True, + encoder_lr=1e-6, + task_lr=5e-5, + max_grad_norm=0.5, + num_workers=2 +) +``` + +### Complete Configuration Reference + +```python +config = TrainingConfig( + # Output + output_dir="./output", + experiment_name="gliner2", + + # Training steps + num_epochs=10, + max_steps=-1, + + # Batch 
size + batch_size=32, + eval_batch_size=64, + gradient_accumulation_steps=1, + + # Learning rates + encoder_lr=1e-5, + task_lr=5e-4, + + # Optimizer + weight_decay=0.01, + adam_beta1=0.9, + adam_beta2=0.999, + adam_epsilon=1e-8, + max_grad_norm=1.0, + + # Learning rate schedule + scheduler_type="linear", # "linear", "cosine", "cosine_restarts", "constant" + warmup_ratio=0.1, + warmup_steps=0, + num_cycles=0.5, + + # Mixed precision + fp16=True, + bf16=False, + + # Checkpointing & Evaluation (saves when evaluating) + eval_strategy="steps", # "epoch", "steps", or "no" (default: "steps") + eval_steps=500, # Evaluate and save every N steps (when eval_strategy="steps") + save_total_limit=3, + save_best=True, + metric_for_best="eval_loss", + greater_is_better=False, + + # Logging + logging_steps=50, # Update progress bar metrics every N steps + # Metrics (loss, learning rate, throughput) are shown in the progress bar + logging_first_step=True, + report_to_wandb=False, # Enable W&B logging for experiment tracking + wandb_project=None, + wandb_entity=None, + wandb_run_name=None, + wandb_tags=[], + wandb_notes=None, + + # Early stopping + early_stopping=False, + early_stopping_patience=3, + early_stopping_threshold=0.0, + + # DataLoader + num_workers=4, + pin_memory=True, + prefetch_factor=2, + + # Other + seed=42, + deterministic=False, + gradient_checkpointing=False, + max_train_samples=-1, + max_eval_samples=-1, + validate_data=True, + strict_validation=False, + + # LoRA (see LoRA section) + use_lora=False, + lora_r=16, + lora_alpha=32.0, + lora_dropout=0.0, + lora_target_modules=["encoder", "span_rep", "classifier", "count_embed", "count_pred"], + save_adapter_only=True, +) +``` + +--- + +## LoRA Training + +LoRA (Low-Rank Adaptation) enables parameter-efficient fine-tuning by training only a small number of additional parameters while keeping the base model frozen. 
+ +> πŸ“š **For a comprehensive guide on training and using multiple LoRA adapters**, see [Tutorial 10: LoRA Adapters - Multi-Domain Inference](./10-lora_adapters.md) + +### Why Use LoRA? + +- **Memory Efficient**: Train with 10-100x fewer parameters +- **Faster Training**: Fewer gradients to compute +- **Multiple Adapters**: Train different adapters for different tasks +- **Easy Deployment**: Checkpoints contain merged weights (ready for inference) +- **Granular Control**: Target specific layers or entire module groups + +### LoRA Module Groups + +GLiNER2 supports both coarse-grained (module groups) and fine-grained (specific layers) control: + +**Module Groups** (apply to entire modules): +- `"encoder"` - All encoder layers (query, key, value, dense) +- `"span_rep"` - All linear layers in span representation +- `"classifier"` - All linear layers in classifier head +- `"count_embed"` - All linear layers in count embedding +- `"count_pred"` - All linear layers in count prediction + +**Specific Encoder Layers** (fine-grained control): +- `"encoder.query"` - Only query projection layers +- `"encoder.key"` - Only key projection layers +- `"encoder.value"` - Only value projection layers +- `"encoder.dense"` - Only dense (FFN) layers + +**Default**: All modules (`["encoder", "span_rep", "classifier", "count_embed", "count_pred"]`) for maximum adaptation. 
+ +### Basic LoRA Training + +```python +from gliner2 import GLiNER2 +from gliner2.training.trainer import GLiNER2Trainer, TrainingConfig + +# Load base model +model = GLiNER2.from_pretrained("fastino/gliner2-base-v1") + +# Configure LoRA training +config = TrainingConfig( + output_dir="./output_lora", + num_epochs=10, + batch_size=16, + + # Enable LoRA + use_lora=True, + lora_r=16, # Rank (higher = more params, better approximation) + lora_alpha=32, # Scaling factor (typically 2*r) + lora_dropout=0.1, # Dropout for regularization + lora_target_modules=["encoder"], # All encoder layers (query, key, value, dense) + save_adapter_only=True, # Save only adapter weights (not full model) + + # Learning rate (task_lr used for LoRA + task heads when LoRA enabled) + task_lr=5e-4, + # encoder_lr is ignored when LoRA is enabled + + # Other settings + fp16=True, + eval_strategy="epoch", # Evaluates and saves at end of each epoch + save_best=True +) + +# Train with LoRA +trainer = GLiNER2Trainer(model, config) +results = trainer.train(train_data="train.jsonl", eval_data="val.jsonl") + +# Checkpoints contain merged weights (ready for inference) +best_model = GLiNER2.from_pretrained("./output_lora/best") +``` + +### LoRA Configuration Parameters + +```python +config = TrainingConfig( + # Enable LoRA + use_lora=True, + + # LoRA rank (r): Controls the rank of low-rank decomposition + # Higher r = more parameters but better approximation + # Typical values: 4, 8, 16, 32, 64 + # Start with 8 or 16 for most tasks + lora_r=16, + + # LoRA alpha: Scaling factor for LoRA updates + # Final scaling is alpha/r + # Typical values: 8, 16, 32 (often 2*r) + # Common practice: alpha = 2 * r + lora_alpha=32, + + # LoRA dropout: Dropout probability applied to LoRA path + # Helps prevent overfitting + # Typical values: 0.0, 0.05, 0.1 + lora_dropout=0.1, + + # Target modules: Which module groups to apply LoRA to + # Module groups: + # - "encoder": All encoder layers (query, key, value, dense) + # - 
"encoder.query": Only query projection layers in encoder + # - "encoder.key": Only key projection layers in encoder + # - "encoder.value": Only value projection layers in encoder + # - "encoder.dense": Only dense (FFN) layers in encoder + # - "span_rep": All linear layers in span representation module + # - "classifier": All linear layers in classifier head + # - "count_embed": All linear layers in count embedding + # - "count_pred": All linear layers in count prediction + # + # Common configurations: + # - ["encoder"]: Encoder only (query, key, value, dense) - good starting point + # - ["encoder.query", "encoder.key", "encoder.value"]: Attention only - memory efficient + # - ["encoder.dense"]: FFN only - alternative approach + # - ["encoder", "span_rep", "classifier"]: Encoder + task heads - better performance + # - ["encoder", "span_rep", "classifier", "count_embed", "count_pred"]: All modules (default) - maximum adaptation + # + # Default: All modules for maximum adaptation capacity + lora_target_modules=["encoder", "span_rep", "classifier", "count_embed", "count_pred"], + + # Save adapter only (recommended) + # When True: saves only LoRA adapter weights (~2-10 MB) + # When False: saves full model with merged weights (~100-500 MB) + save_adapter_only=True, + + # Learning rate for LoRA parameters + # When LoRA is enabled, task_lr is used for both LoRA and task-specific heads + task_lr=5e-4, # Typical: 1e-4 to 1e-3 +) +``` + +### LoRA Training Examples + +**Example 1: Memory-Constrained Training** + +```python +# Train on GPU with limited memory +config = TrainingConfig( + output_dir="./lora_small_memory", + use_lora=True, + lora_r=8, # Smaller rank for less memory + lora_alpha=16, + batch_size=32, # Can use larger batch with LoRA + gradient_accumulation_steps=1, + task_lr=5e-4, + fp16=True +) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data="train.jsonl") +``` + +**Example 2: High-Performance LoRA** + +```python +# Use higher rank and include 
task heads for better performance +config = TrainingConfig( + output_dir="./lora_high_perf", + use_lora=True, + lora_r=32, # Higher rank + lora_alpha=64, + lora_dropout=0.05, + lora_target_modules=["encoder", "span_rep", "classifier"], # Encoder + task heads + batch_size=16, + task_lr=1e-3, # Slightly higher LR + num_epochs=15 +) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data="train.jsonl") +``` + +**Example 3: Domain Adaptation with LoRA** + +```python +# Fine-tune for specific domain with LoRA +config = TrainingConfig( + output_dir="./lora_medical", + use_lora=True, + lora_r=16, + lora_alpha=32, + lora_dropout=0.1, + batch_size=16, + task_lr=5e-4, + num_epochs=20, + warmup_ratio=0.05 +) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=medical_examples) +``` + +### LoRA vs Full Fine-tuning + +| Aspect | LoRA | Full Fine-tuning | +|--------|------|------------------| +| Trainable Parameters | ~0.1-1% of model | 100% of model | +| Memory Usage | Low | High | +| Training Speed | Fast | Slower | +| Checkpoint Size | Small (merged) | Large | +| Performance | Good (often comparable) | Best | +| Use Case | Limited data, multiple tasks | Large datasets, single task | + +### LoRA Best Practices + +1. **Start with default settings**: `r=16`, `alpha=32`, `dropout=0.1` +2. **Increase rank if needed**: If performance is insufficient, try `r=32` or `r=64` +3. **Use dropout for regularization**: Set `lora_dropout=0.1` to prevent overfitting +4. **Target attention layers first**: Start with `["encoder.query", "encoder.key", "encoder.value"]`, add `"encoder.dense"` if needed +5. **Higher learning rate**: LoRA typically works well with `task_lr=5e-4` to `1e-3` +6.
**Checkpoint merging**: Checkpoints automatically contain merged weights (ready for inference) + +### Loading Checkpoints + +```python +# Load model from checkpoint (for inference or continued training) +trainer = GLiNER2Trainer(model, config) + +# Load checkpoint (weights are merged in checkpoint) +trainer.load_checkpoint("./output_lora/checkpoint-1000") + +# Continue training (LoRA will be re-applied if use_lora=True) +trainer.train(train_data="train.jsonl") +``` + +**Note**: Checkpoints do not save optimizer/scheduler state. Training always starts fresh, but model weights are loaded. + +--- + +## Advanced Topics + +### Custom Metrics + +```python +def compute_metrics(model, eval_dataset): + """Custom metric computation function.""" + # Your custom evaluation logic + # For example, compute F1 score on entities + + metrics = {} + # ... compute metrics ... + metrics["f1"] = 0.85 + metrics["precision"] = 0.87 + metrics["recall"] = 0.83 + + return metrics + +trainer = GLiNER2Trainer( + model=model, + config=config, + compute_metrics=compute_metrics +) + +trainer.train(train_data=examples, eval_data=eval_examples) +``` + +### Loading Checkpoints + +```python +trainer = GLiNER2Trainer(model, config) + +# Load checkpoint (model weights only, no optimizer state) +trainer.load_checkpoint("./output/checkpoint-1000") + +# Continue training (starts fresh with loaded weights) +trainer.train(train_data=examples) +``` + +**Note**: Checkpoints contain model weights only. Training state (optimizer, scheduler) is not saved, so training always starts fresh. 
+ +### Distributed Training + +```python +# Launch with torchrun +# torchrun --nproc_per_node=4 train_script.py + +import os + +config = TrainingConfig( + output_dir="./output", + num_epochs=10, + local_rank=int(os.environ.get("LOCAL_RANK", -1)) # Auto-detect DDP +) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=examples) +``` + +### Learning Rate Schedules + +```python +# Linear warmup + linear decay (default) +config = TrainingConfig( + scheduler_type="linear", + warmup_ratio=0.1 +) + +# Cosine annealing +config = TrainingConfig( + scheduler_type="cosine", + warmup_ratio=0.05 +) + +# Cosine with restarts +config = TrainingConfig( + scheduler_type="cosine_restarts", + warmup_ratio=0.05, + num_cycles=3 +) + +# Constant LR after warmup +config = TrainingConfig( + scheduler_type="constant", + warmup_steps=500 +) +``` + +### Weights & Biases Integration + +```python +config = TrainingConfig( + output_dir="./output", + report_to_wandb=True, + wandb_project="my-gliner-project", + wandb_entity="my-team", + wandb_run_name="experiment-1", + wandb_tags=["ner", "entity-extraction"], + wandb_notes="Testing new architecture" +) + +trainer = GLiNER2Trainer(model, config) +trainer.train(train_data=examples) +# Metrics automatically logged to W&B +``` + +### Data Augmentation + +```python +from typing import List + +from gliner2.training.data import InputExample, TrainingDataset + +def augment_example(example: InputExample) -> List[InputExample]: + """Create augmented versions of an example.""" + augmented = [example] # Original + + # Reorder entities (reverse-sorted keys) as a simple augmentation + if len(example.entities) > 1: + shuffled_entities = dict(sorted(example.entities.items(), reverse=True)) + augmented.append(InputExample( + text=example.text, + entities=shuffled_entities + )) + + return augmented + +# Apply augmentation +dataset = TrainingDataset.load("train.jsonl") +augmented_examples = [] +for ex in dataset: + augmented_examples.extend(augment_example(ex)) + +augmented_dataset =
TrainingDataset(augmented_examples) +trainer.train(train_data=augmented_dataset) +``` + +--- + +## Troubleshooting + +### Out of Memory (OOM) + +**Solutions:** +```python +# 1. Reduce batch size +config = TrainingConfig(batch_size=4) + +# 2. Use gradient accumulation +config = TrainingConfig( + batch_size=4, + gradient_accumulation_steps=8 # Effective batch = 32 +) + +# 3. Enable gradient checkpointing +config = TrainingConfig(gradient_checkpointing=True) + +# 4. Use mixed precision +config = TrainingConfig(fp16=True) + +# 5. Use LoRA (most memory efficient) +config = TrainingConfig( + use_lora=True, + lora_r=8, + batch_size=32 # Can use larger batch with LoRA +) + +# 6. Reduce workers +config = TrainingConfig(num_workers=2) +``` + +### Training is Slow + +**Solutions:** +```python +# 1. Increase batch size (if memory allows) +config = TrainingConfig(batch_size=64) + +# 2. Increase workers +config = TrainingConfig(num_workers=8) + +# 3. Use mixed precision +config = TrainingConfig(fp16=True) + +# 4. Reduce evaluation frequency (also reduces checkpoint saves) +config = TrainingConfig( + eval_strategy="steps", + eval_steps=1000 # Evaluate and save every 1000 steps +) + +# 5. Use LoRA (faster training) +config = TrainingConfig(use_lora=True) +``` + +### Validation Errors + +```python +# Check specific errors +dataset = TrainingDataset(examples) +report = dataset.validate(raise_on_error=False) + +print(f"Invalid examples: {report['invalid_indices']}") +for error in report['errors'][:10]: + print(error) + +# Fix common issues: +# 1. Entity not in text +example = InputExample( + text="John works here", + entities={"person": ["John Smith"]} # ERROR: "John Smith" not in text +) +# Fix: Use exact match +example = InputExample( + text="John works here", + entities={"person": ["John"]} # OK +) + +# 2. 
Empty entities +example = InputExample( + text="Some text", + entities={"person": []} # ERROR: empty list +) +# Fix: Remove empty entity types +example = InputExample( + text="Some text", + entities={} # OK if other tasks present +) + +# 3. Use loose validation during development +dataset.validate(strict=False, raise_on_error=False) +``` + +### Model Not Learning + +**Solutions:** +```python +# 1. Check learning rates +config = TrainingConfig( + encoder_lr=1e-5, # Try: 5e-6, 1e-5, 5e-5 + task_lr=5e-4 # Try: 1e-4, 5e-4, 1e-3 +) + +# 2. Increase training epochs +config = TrainingConfig(num_epochs=20) + +# 3. Check warmup +config = TrainingConfig(warmup_ratio=0.1) + +# 4. Reduce weight decay +config = TrainingConfig(weight_decay=0.001) + +# 5. Try different scheduler +config = TrainingConfig(scheduler_type="cosine") + +# 6. Check data quality +dataset.print_stats() +dataset.validate() +``` + +### LoRA-Specific Issues + +**Issue: LoRA not reducing memory** +```python +# Ensure LoRA is enabled +config = TrainingConfig(use_lora=True) + +# Use smaller rank +config = TrainingConfig(use_lora=True, lora_r=8) + +# Target only encoder (fewer modules = less memory) +config = TrainingConfig( + use_lora=True, + lora_target_modules=["encoder"] +) + +# Target only attention layers (even less memory) +config = TrainingConfig( + use_lora=True, + lora_target_modules=["encoder.query", "encoder.key", "encoder.value"] +) +``` + +**Issue: LoRA performance worse than full fine-tuning** +```python +# Increase rank +config = TrainingConfig(use_lora=True, lora_r=32) + +# Add task heads to target modules +config = TrainingConfig( + use_lora=True, + lora_target_modules=["encoder", "span_rep", "classifier"] +) + +# Target all modules for maximum adaptation +config = TrainingConfig( + use_lora=True, + lora_target_modules=["encoder", "span_rep", "classifier", "count_embed", "count_pred"] +) + +# Increase learning rate +config = TrainingConfig(use_lora=True, task_lr=1e-3) + +# Train longer +config 
= TrainingConfig(use_lora=True, num_epochs=20) +``` + +--- + +## Best Practices + +1. **Always validate data before training:** + ```python + dataset.validate() + dataset.print_stats() + ``` + +2. **Start with small subset for testing:** + ```python + config = TrainingConfig(max_train_samples=100) + ``` + +3. **Use early stopping for long training:** + ```python + config = TrainingConfig( + early_stopping=True, + early_stopping_patience=5 + ) + ``` + +4. **Save intermediate checkpoints:** + ```python + config = TrainingConfig( + eval_strategy="steps", # Evaluates and saves every N steps + eval_steps=500, + save_best=True + ) + ``` + +5. **Monitor training with W&B:** + ```python + config = TrainingConfig( + report_to_wandb=True, + wandb_project="my-project" + ) + ``` + +6. **Use descriptive entity types and add descriptions:** + ```python + example = InputExample( + text="...", + entities={...}, + entity_descriptions={ + "person": "Full name of a person", + "company": "Business organization name" + } + ) + ``` + +7. **Split your data properly:** + ```python + train, val, test = dataset.split(0.8, 0.1, 0.1) + ``` + +8. **Use appropriate learning rates:** + - Full fine-tuning: Encoder LR `1e-6` to `5e-5` (typically `1e-5`), Task LR `1e-4` to `1e-3` (typically `5e-4`) + - LoRA: Task LR `1e-4` to `1e-3` (typically `5e-4`) + +9. **Consider LoRA for memory-constrained scenarios:** + ```python + config = TrainingConfig( + use_lora=True, + lora_r=16, + lora_alpha=32 + ) + ``` + +10. 
**Document your experiments:** + ```python + config = TrainingConfig( + experiment_name="v1_medical_ner", + wandb_notes="Testing with LoRA and augmented data" + ) + ``` + +--- + +## Summary + +GLiNER2 provides a flexible and powerful framework for training information extraction models: + +- **Multiple data formats**: JSONL, InputExample, TrainingDataset, raw dicts +- **Four task types**: Entities, Classifications, Structures, Relations +- **Comprehensive validation**: Automatic data validation and statistics +- **Production-ready training**: FP16, gradient accumulation, distributed training +- **LoRA support**: Parameter-efficient fine-tuning with minimal memory usage +- **Extensive configuration**: 40+ config options for fine-grained control +- **Easy to use**: Quick start in 10 lines of code + +Start with the Quick Start examples and gradually explore advanced features as needed!