feat(hugging-face): Add official ecosystem skills
Import the official Hugging Face ecosystem skills and sync the
existing local coverage with upstream metadata and assets.

Regenerate the canonical catalog, plugin mirrors, docs, and release
notes after the maintainer merge batch so main stays in sync.

Fixes #417
@@ -1,7 +1,7 @@
 {
   "name": "antigravity-awesome-skills",
   "version": "9.1.0",
-  "description": "Plugin-safe Claude Code distribution of Antigravity Awesome Skills with 1,318 supported skills.",
+  "description": "Plugin-safe Claude Code distribution of Antigravity Awesome Skills with 1,326 supported skills.",
   "author": {
     "name": "sickn33 and contributors",
     "url": "https://github.com/sickn33/antigravity-awesome-skills"
28
CATALOG.md
@@ -2,7 +2,7 @@
 
 Generated at: 2026-02-08T00:00:00.000Z
 
-Total skills: 1332
+Total skills: 1340
 
 ## architecture (88)
@@ -171,7 +171,7 @@ Total skills: 1332
 | `warren-buffett` | Agent that simulates Warren Buffett, the greatest investor of the 20th and 21st centuries, CEO of Berkshire Hathaway, disciple of Benjamin Graham and intellectual partner of Char... | persona, investing, value-investing, business | persona, investing, value-investing, business, warren, buffett, agente, que, simula, maior, investidor, do |
 | `whatsapp-automation` | Automate WhatsApp Business tasks via Rube MCP (Composio): send messages, manage templates, upload media, and handle contacts. Always search tools first for c... | whatsapp | whatsapp, automation, automate, business, tasks, via, rube, mcp, composio, send, messages, upload |
 
-## data-ai (251)
+## data-ai (252)
 
 | Skill | Description | Tags | Triggers |
 | --- | --- | --- | --- |
@@ -322,6 +322,7 @@ Total skills: 1332
 | `google-analytics-automation` | Automate Google Analytics tasks via Rube MCP (Composio): run reports, list accounts/properties, funnels, pivots, key events. Always search tools first for cu... | google, analytics | google, analytics, automation, automate, tasks, via, rube, mcp, composio, run, reports, list |
 | `googlesheets-automation` | Automate Google Sheets operations (read, write, format, filter, manage spreadsheets) via Rube MCP (Composio). Read/write data, manage tabs, apply formatting,... | googlesheets | googlesheets, automation, automate, google, sheets, operations, read, write, format, filter, spreadsheets, via |
 | `hosted-agents-v2-py` | Build hosted agents using Azure AI Projects SDK with ImageBasedHostedAgentDefinition. Use when creating container-based agents in Azure AI Foundry. | hosted, agents, v2, py | hosted, agents, v2, py, azure, ai, sdk, imagebasedhostedagentdefinition, creating, container, foundry |
+| `hugging-face-community-evals` | Run local evaluations for Hugging Face Hub models with inspect-ai or lighteval. | hugging, face, community, evals | hugging, face, community, evals, run, local, evaluations, hub, models, inspect, ai, lighteval |
 | `hugging-face-datasets` | Create and manage datasets on Hugging Face Hub. Supports initializing repos, defining configs/system prompts, streaming row updates, and SQL-based dataset qu... | hugging, face, datasets | hugging, face, datasets, hub, supports, initializing, repos, defining, configs, prompts, streaming, row |
 | `hybrid-search-implementation` | Combine vector and keyword search for improved retrieval. Use when implementing RAG systems, building search engines, or when neither approach alone provides... | hybrid, search | hybrid, search, combine, vector, keyword, improved, retrieval, implementing, rag, building, engines, neither |
 | `iconsax-library` | Extensive icon library and AI-driven icon generation skill for premium UI/UX design. | iconsax, library | iconsax, library, extensive, icon, ai, driven, generation, skill, premium, ui, ux |
@@ -427,7 +428,7 @@ Total skills: 1332
 | `youtube-automation` | Automate YouTube tasks via Rube MCP (Composio): upload videos, manage playlists, search content, get analytics, and handle comments. Always search tools firs... | youtube | youtube, automation, automate, tasks, via, rube, mcp, composio, upload, videos, playlists, search |
 | `zapier-make-patterns` | You are a no-code automation architect who has built thousands of Zaps and Scenarios for businesses of all sizes. You've seen automations that save companies... | zapier, make | zapier, make, no, code, automation, architect, who, built, thousands, zaps, scenarios, businesses |
 
-## development (182)
+## development (185)
 
 | Skill | Description | Tags | Triggers |
 | --- | --- | --- | --- |
@@ -532,8 +533,10 @@ Total skills: 1332
 | `go-rod-master` | Comprehensive guide for browser automation and web scraping with go-rod (Chrome DevTools Protocol) including stealth anti-bot-detection patterns. | go, rod, master | go, rod, master, browser, automation, web, scraping, chrome, devtools, protocol, including, stealth |
 | `golang-pro` | Master Go 1.21+ with modern patterns, advanced concurrency, performance optimization, and production-ready microservices. | golang | golang, pro, go, 21, concurrency, performance, optimization, microservices |
 | `hono` | Build ultra-fast web APIs and full-stack apps with Hono — runs on Cloudflare Workers, Deno, Bun, Node.js, and any WinterCG-compatible runtime. | hono, edge, cloudflare-workers, bun, deno, api, typescript, web-standards | hono, edge, cloudflare-workers, bun, deno, api, typescript, web-standards, ultra, fast, web, apis |
-| `hugging-face-dataset-viewer` | Use this skill for Hugging Face Dataset Viewer API workflows that fetch subset/split metadata, paginate rows, search text, apply filters, download parquet UR... | hugging, face, dataset, viewer | hugging, face, dataset, viewer, skill, api, fetch, subset, split, metadata, paginate, rows |
+| `hugging-face-dataset-viewer` | Query Hugging Face datasets through the Dataset Viewer API for splits, rows, search, filters, and parquet links. | hugging, face, dataset, viewer | hugging, face, dataset, viewer, query, datasets, through, api, splits, rows, search, filters |
 | `hugging-face-evaluation` | Add and manage evaluation results in Hugging Face model cards. Supports extracting eval tables from README content, importing scores from Artificial Analysis... | hugging, face, evaluation | hugging, face, evaluation, add, results, model, cards, supports, extracting, eval, tables, readme |
+| `hugging-face-gradio` | Build or edit Gradio apps, layouts, components, and chat interfaces in Python. | hugging, face, gradio | hugging, face, gradio, edit, apps, layouts, components, chat, interfaces, python |
+| `hugging-face-papers` | Read and analyze Hugging Face paper pages or arXiv papers with markdown and papers API metadata. | hugging, face, papers | hugging, face, papers, read, analyze, paper, pages, arxiv, markdown, api, metadata |
 | `hugging-face-tool-builder` | Your purpose now is to create reusable command line scripts and utilities for using the Hugging Face API, allowing chaining, piping and intermediate proce... | hugging, face, builder | hugging, face, builder, purpose, now, reusable, command, line, scripts, utilities, api, allowing |
 | `ios-debugger-agent` | Debug the current iOS project on a booted simulator with XcodeBuildMCP. | ios, debugger, agent | ios, debugger, agent, debug, current, booted, simulator, xcodebuildmcp |
 | `javascript-mastery` | 33+ essential JavaScript concepts every developer should know, inspired by [33-js-concepts](https://github.com/leonardomso/33-js-concepts). | javascript, mastery | javascript, mastery, 33, essential, concepts, every, developer, should, know, inspired, js, https |
@@ -603,6 +606,7 @@ Total skills: 1332
 | `tavily-web` | Web search, content extraction, crawling, and research capabilities using Tavily API. Use when you need to search the web for current information, extracting... | tavily, web | tavily, web, search, content, extraction, crawling, research, capabilities, api, current, information, extracting |
 | `telegram` | Complete Telegram Bot API integration. Setup with BotFather, messages, webhooks, inline keyboards, groups, channels. Node.js and Python boilerplates. | messaging, telegram, bots, webhooks | messaging, telegram, bots, webhooks, integracao, completa, com, bot, api, setup, botfather, mensagens |
 | `temporal-python-testing` | Comprehensive testing approaches for Temporal workflows using pytest, progressive disclosure resources for specific testing scenarios. | temporal, python | temporal, python, testing, approaches, pytest, progressive, disclosure, resources, specific, scenarios |
+| `transformers-js` | Run Hugging Face models in JavaScript or TypeScript with Transformers.js in Node.js or the browser. | transformers, js | transformers, js, run, hugging, face, models, javascript, typescript, node, browser |
 | `trigger-dev` | You are a Trigger.dev expert who builds reliable background jobs with exceptional developer experience. You understand that Trigger.dev bridges the gap betwe... | trigger, dev | trigger, dev, who, reliable, background, jobs, exceptional, developer, experience, understand, bridges, gap |
 | `trpc-fullstack` | Build end-to-end type-safe APIs with tRPC — routers, procedures, middleware, subscriptions, and Next.js/React integration patterns. | typescript, trpc, api, fullstack, nextjs, react, type-safety | typescript, trpc, api, fullstack, nextjs, react, type-safety, type, safe, apis, routers, procedures |
 | `typescript-advanced-types` | Comprehensive guidance for mastering TypeScript's advanced type system including generics, conditional types, mapped types, template literal types, and utili... | typescript, advanced, types | typescript, advanced, types, guidance, mastering, type, including, generics, conditional, mapped, literal, utility |
@@ -614,7 +618,7 @@ Total skills: 1332
 | `zod-validation-expert` | Expert in Zod — TypeScript-first schema validation. Covers parsing, custom errors, refinements, type inference, and integration with React Hook Form, Next.js... | zod, validation | zod, validation, typescript, first, schema, covers, parsing, custom, errors, refinements, type, inference |
 | `zustand-store-ts` | Create Zustand stores following established patterns with proper TypeScript types and middleware. | zustand, store, ts | zustand, store, ts, stores, following, established, proper, typescript, types, middleware |
 
-## general (326)
+## general (328)
 
 | Skill | Description | Tags | Triggers |
 | --- | --- | --- | --- |
@@ -774,8 +778,10 @@ Total skills: 1332
 | `hig-technologies` | Check for .claude/apple-design-context.md before asking questions. Use existing context and only ask for information not already covered. | hig, technologies | hig, technologies, check, claude, apple, context, md, before, asking, questions, existing, ask |
 | `hosted-agents` | Build background agents in sandboxed environments. Use for hosted coding agents, sandboxed VMs, Modal sandboxes, and remote coding environments. | hosted, agents | hosted, agents, background, sandboxed, environments, coding, vms, modal, sandboxes, remote |
 | `hubspot-integration` | Authentication for single-account integrations | hubspot, integration | hubspot, integration, authentication, single, account, integrations |
-| `hugging-face-cli` | The hf CLI provides direct terminal access to the Hugging Face Hub for downloading, uploading, and managing repositories, cache, and compute resources. | hugging, face, cli | hugging, face, cli, hf, provides, direct, terminal, access, hub, downloading, uploading, managing |
+| `hugging-face-cli` | Use the Hugging Face Hub CLI (`hf`) to download, upload, and manage models, datasets, and Spaces. | hugging, face, cli | hugging, face, cli, hub, hf, download, upload, models, datasets, spaces |
+| `hugging-face-model-trainer` | Train or fine-tune TRL language models on Hugging Face Jobs, including SFT, DPO, GRPO, and GGUF export. | hugging, face, model, trainer | hugging, face, model, trainer, train, fine, tune, trl, language, models, jobs, including |
 | `hugging-face-paper-publisher` | Publish and manage research papers on Hugging Face Hub. Supports creating paper pages, linking papers to models/datasets, claiming authorship, and generating... | hugging, face, paper, publisher | hugging, face, paper, publisher, publish, research, papers, hub, supports, creating, pages, linking |
+| `hugging-face-vision-trainer` | Train or fine-tune vision models on Hugging Face Jobs for detection, classification, and SAM or SAM2 segmentation. | hugging, face, vision, trainer | hugging, face, vision, trainer, train, fine, tune, models, jobs, detection, classification, sam |
 | `ilya-sutskever` | Agent that simulates Ilya Sutskever, OpenAI co-founder, former Chief Scientist, and founder of SSI. Use when you want perspectives on: AGI safety-first, consci... | persona, agi, safety, scaling-laws, openai | persona, agi, safety, scaling-laws, openai, ilya, sutskever, agente, que, simula, co, fundador |
 | `infinite-gratitude` | Multi-agent research skill for parallel research execution (10 agents, battle-tested with real case studies). | infinite, gratitude | infinite, gratitude, multi, agent, research, skill, parallel, execution, 10, agents, battle, tested |
 | `inngest` | You are an Inngest expert who builds reliable background processing without managing infrastructure. You understand that serverless doesn't mean you can't ha... | inngest | inngest, who, reliable, background, processing, without, managing, infrastructure, understand, serverless, doesn, mean |
@@ -1017,13 +1023,13 @@ Total skills: 1332
 | `gitops-workflow` | Complete guide to implementing GitOps workflows with ArgoCD and Flux for automated Kubernetes deployments. | gitops | gitops, complete, implementing, argocd, flux, automated, kubernetes, deployments |
 | `grafana-dashboards` | Create and manage production-ready Grafana dashboards for comprehensive system observability. | grafana, dashboards | grafana, dashboards, observability |
 | `helm-chart-scaffolding` | Comprehensive guidance for creating, organizing, and managing Helm charts for packaging and deploying Kubernetes applications. | helm, chart | helm, chart, scaffolding, guidance, creating, organizing, managing, charts, packaging, deploying, kubernetes, applications |
 | `hugging-face-jobs` | Run any workload on fully managed Hugging Face infrastructure. No local setup required—jobs run on cloud CPUs, GPUs, or TPUs and can persist results to the H... | hugging, face, jobs | hugging, face, jobs, run, any, workload, fully, managed, infrastructure, no, local, setup |
 | `hugging-face-model-trainer` | Train language models using TRL (Transformer Reinforcement Learning) on fully managed Hugging Face infrastructure. No local GPU setup required—models train o... | hugging, face, model, trainer | hugging, face, model, trainer, train, language, models, trl, transformer, reinforcement, learning, fully |
+| `hugging-face-trackio` | Track ML experiments with Trackio using Python logging, alerts, and CLI metric retrieval. | hugging, face, trackio | hugging, face, trackio, track, ml, experiments, python, logging, alerts, cli, metric, retrieval |
 | `hybrid-cloud-architect` | Expert hybrid cloud architect specializing in complex multi-cloud solutions across AWS/Azure/GCP and private clouds (OpenStack/VMware). | hybrid, cloud | hybrid, cloud, architect, specializing, complex, multi, solutions, aws, azure, gcp, private, clouds |
 | `hybrid-cloud-networking` | Configure secure, high-performance connectivity between on-premises and cloud environments using VPN, Direct Connect, and ExpressRoute. | hybrid, cloud, networking | hybrid, cloud, networking, configure, secure, high, performance, connectivity, between, premises, environments, vpn |
 | `istio-traffic-management` | Comprehensive guide to Istio traffic management for production service mesh deployments. | istio, traffic | istio, traffic, mesh, deployments |
 | `iterate-pr` | Iterate on a PR until CI passes. Use when you need to fix CI failures, address review feedback, or continuously push fixes until all checks are green. Automa... | iterate, pr | iterate, pr, until, ci, passes, fix, failures, address, review, feedback, continuously, push |
 | `java-pro` | Master Java 21+ with modern features like virtual threads, pattern matching, and Spring Boot 3.x. Expert in the latest Java ecosystem including GraalVM, Proj... | java | java, pro, 21, features, like, virtual, threads, matching, spring, boot, latest, ecosystem |
+| `jq` | Expert jq usage for JSON querying, filtering, transformation, and pipeline integration. Practical patterns for real shell workflows. | jq, json, shell, cli, data-transformation, bash | jq, json, shell, cli, data-transformation, bash, usage, querying, filtering, transformation, pipeline, integration |
 | `k6-load-testing` | Comprehensive k6 load testing skill for API, browser, and scalability testing. Write realistic load scenarios, analyze results, and integrate with CI/CD. | k6, load-testing, performance, api-testing, ci-cd | k6, load-testing, performance, api-testing, ci-cd, load, testing, skill, api, browser, scalability, write |
 | `kubernetes-architect` | Expert Kubernetes architect specializing in cloud-native infrastructure, advanced GitOps workflows (ArgoCD/Flux), and enterprise container orchestration. | kubernetes | kubernetes, architect, specializing, cloud, native, infrastructure, gitops, argocd, flux, enterprise, container, orchestration |
 | `kubernetes-deployment` | Kubernetes deployment workflow for container orchestration, Helm charts, service mesh, and production-ready K8s configurations. | kubernetes, deployment | kubernetes, deployment, container, orchestration, helm, charts, mesh, k8s, configurations |
@@ -1069,7 +1075,7 @@ Total skills: 1332
 | `whatsapp-cloud-api` | WhatsApp Business Cloud API (Meta) integration. Messages, templates, HMAC-SHA256 webhooks, customer service automation. Node.js and Python boilerplates. | messaging, whatsapp, meta, webhooks | messaging, whatsapp, meta, webhooks, cloud, api, integracao, com, business, mensagens, hmac, sha256 |
 | `x-twitter-scraper` | X (Twitter) data platform skill — tweet search, user lookup, follower extraction, engagement metrics, giveaway draws, monitoring, webhooks, 19 extraction too... | [twitter, x-api, scraping, mcp, social-media, data-extraction, giveaway, monitoring, webhooks] | [twitter, x-api, scraping, mcp, social-media, data-extraction, giveaway, monitoring, webhooks], twitter, scraper, data |
 
-## security (166)
+## security (167)
 
 | Skill | Description | Tags | Triggers |
 | --- | --- | --- | --- |
@@ -1139,6 +1145,7 @@ Total skills: 1332
 | `graphql-architect` | Master modern GraphQL with federation, performance optimization, and enterprise security. Build scalable schemas, implement advanced caching, and design real... | graphql | graphql, architect, federation, performance, optimization, enterprise, security, scalable, schemas, caching, real, time |
 | `grpc-golang` | Build production-ready gRPC services in Go with mTLS, streaming, and observability. Use when designing Protobuf contracts with Buf or implementing secure ser... | grpc, golang | grpc, golang, go, mtls, streaming, observability, designing, protobuf, contracts, buf, implementing, secure |
 | `html-injection-testing` | Identify and exploit HTML injection vulnerabilities that allow attackers to inject malicious HTML content into web applications. This vulnerability enables a... | html, injection | html, injection, testing, identify, exploit, vulnerabilities, allow, attackers, inject, malicious, content, web |
+| `hugging-face-jobs` | Run workloads on Hugging Face Jobs with managed CPUs, GPUs, TPUs, secrets, and Hub persistence. | hugging, face, jobs | hugging, face, jobs, run, workloads, managed, cpus, gpus, tpus, secrets, hub, persistence |
 | `incident-responder` | Expert SRE incident responder specializing in rapid problem resolution, modern observability, and comprehensive incident management. | incident, responder | incident, responder, sre, specializing, rapid, problem, resolution, observability |
 | `incident-response-incident-response` | Use when working with incident response incident response | incident, response | incident, response, working |
 | `incident-response-smart-fix` | [Extended thinking: This workflow implements a sophisticated debugging and resolution pipeline that leverages AI-assisted debugging tools and observability p... | incident, response, fix | incident, response, fix, smart, extended, thinking, implements, sophisticated, debugging, resolution, pipeline, leverages |
@@ -1276,7 +1283,7 @@ Total skills: 1332
 | `wiki-qa` | Answer repository questions grounded entirely in source code evidence. Use when user asks a question about the codebase, user wants to understand a specific ... | wiki, qa | wiki, qa, answer, repository, questions, grounded, entirely, source, code, evidence, user, asks |
 | `windows-privilege-escalation` | Provide systematic methodologies for discovering and exploiting privilege escalation vulnerabilities on Windows systems during penetration testing engagements. | windows, privilege, escalation | windows, privilege, escalation, provide, systematic, methodologies, discovering, exploiting, vulnerabilities, during, penetration, testing |
 
-## workflow (100)
+## workflow (101)
 
 | Skill | Description | Tags | Triggers |
 | --- | --- | --- | --- |
@@ -1367,6 +1374,7 @@ Total skills: 1332
 | `team-collaboration-issue` | You are a GitHub issue resolution expert specializing in systematic bug investigation, feature implementation, and collaborative development workflows. Your ... | team, collaboration, issue | team, collaboration, issue, github, resolution, specializing, systematic, bug, investigation, feature, collaborative, development |
 | `telegram-automation` | Automate Telegram tasks via Rube MCP (Composio): send messages, manage chats, share photos/documents, and handle bot commands. Always search tools first for ... | telegram | telegram, automation, automate, tasks, via, rube, mcp, composio, send, messages, chats, share |
 | `tiktok-automation` | Automate TikTok tasks via Rube MCP (Composio): upload/publish videos, post photos, manage content, and view user profiles/stats. Always search tools first fo... | tiktok | tiktok, automation, automate, tasks, via, rube, mcp, composio, upload, publish, videos, post |
+| `tmux` | Expert tmux session, window, and pane management for terminal multiplexing, persistent remote workflows, and shell scripting automation. | tmux, terminal, multiplexer, sessions, shell, remote, automation | tmux, terminal, multiplexer, sessions, shell, remote, automation, session, window, pane, multiplexing, persistent |
 | `todoist-automation` | Automate Todoist task management, projects, sections, filtering, and bulk operations via Rube MCP (Composio). Always search tools first for current schemas. | todoist | todoist, automation, automate, task, sections, filtering, bulk, operations, via, rube, mcp, composio |
 | `track-management` | Use this skill when creating, managing, or working with Conductor tracks - the logical work units for features, bugs, and refactors. Applies to spec.md, plan... | track | track, skill, creating, managing, working, conductor, tracks, logical, work, units, features, bugs |
 | `trello-automation` | Automate Trello boards, cards, and workflows via Rube MCP (Composio). Create cards, manage lists, assign members, and search across boards programmatically. | trello | trello, automation, automate, boards, cards, via, rube, mcp, composio, lists, assign, members |
47
CHANGELOG.md
@@ -9,6 +9,53 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [9.2.0] - 2026-03-29 - "Hugging Face Ecosystem and Shell Workflow Expansion"
+
+> Installable skill library update for Claude Code, Cursor, Codex CLI, Gemini CLI, Antigravity, and related AI coding assistants.
+
+Start here:
+
+- Install: `npx antigravity-awesome-skills`
+- Choose your tool: [README -> Choose Your Tool](https://github.com/sickn33/antigravity-awesome-skills#choose-your-tool)
+- Best skills by tool: [README -> Best Skills By Tool](https://github.com/sickn33/antigravity-awesome-skills#best-skills-by-tool)
+- Bundles: [docs/users/bundles.md](https://github.com/sickn33/antigravity-awesome-skills/blob/main/docs/users/bundles.md)
+- Workflows: [docs/users/workflows.md](https://github.com/sickn33/antigravity-awesome-skills/blob/main/docs/users/workflows.md)
+
+This release expands practical day-to-day coverage for Claude Code, Cursor, Codex CLI, Gemini CLI, and similar agent workflows. It adds a full batch of Hugging Face ecosystem skills, new shell and terminal expertise for `jq` and `tmux`, a new `viboscope` collaboration skill, and stronger Odoo guidance for safer credentials and more reliable EDI flows.
+
+## New Skills
+
+- **hugging-face-community-evals** - run local Hugging Face Hub model evaluations with `inspect-ai` and `lighteval`.
+- **hugging-face-gradio** - build and edit Gradio demos, layouts, and chat interfaces in Python.
+- **hugging-face-papers** - read and analyze Hugging Face paper pages and arXiv-linked metadata.
+- **hugging-face-trackio** - track ML experiments with Trackio logging, alerts, and CLI metric retrieval.
+- **hugging-face-vision-trainer** - train and fine-tune detection, classification, and SAM or SAM2 vision models on Hugging Face Jobs.
+- **transformers-js** - run Hugging Face models in JavaScript and TypeScript with Transformers.js.
+- **jq** - add expert JSON querying, transformation, and shell pipeline guidance for terminal-first workflows (PR #414).
+- **tmux** - add advanced session, pane, scripting, and remote terminal workflow guidance (PR #414).
+- **viboscope** - add psychological compatibility matching guidance for cofounder, collaborator, and relationship discovery workflows (PR #415).
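As a taste of the new shell coverage, here is a minimal sketch of the kind of pipeline the `jq` skill documents; the JSON payload and field names are invented for illustration, not taken from the real catalog format:

```shell
#!/bin/sh
# Filter a catalog-style JSON payload down to the names of skills carrying a
# given tag. The payload below is sample data invented for illustration.
payload='{"skills":[{"name":"jq","tags":["shell","cli"]},{"name":"tmux","tags":["terminal"]}]}'

# select() keeps entries whose tag array contains "terminal"; -r emits raw strings.
echo "$payload" | jq -r '.skills[] | select(.tags | index("terminal")) | .name'
# prints: tmux
```

Note that `index` returns a position (possibly `0`) or `null`, and in jq only `null` and `false` are falsy, so a match at position 0 still passes `select`.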
+
+## Improvements
+
+- **Hugging Face official skill sync** - refreshed local Hugging Face coverage and attribution for `hugging-face-cli`, `hugging-face-dataset-viewer`, `hugging-face-jobs`, `hugging-face-model-trainer`, and `hugging-face-paper-publisher`, while packaging the missing official ecosystem skills into the repo.
+- **Odoo security hardening** - merged safer credential handling for `odoo-woocommerce-bridge` by replacing hardcoded secrets with environment-variable lookups (PR #413).
+- **Odoo EDI resilience** - improved `odoo-edi-connector` with idempotency checks, partner verification, dynamic X12 date handling, and safer environment-based configuration (PR #416).
+- **Maintainer and release docs** - folded in the latest maintainer guidance around risk-label sync, repo-state hygiene, and release/CI workflow consistency.
+
+## Who should care
+
+- **Claude Code, Codex CLI, Cursor, and Gemini CLI users** get broader Hugging Face ecosystem coverage for datasets, Jobs, evaluations, papers, Trackio, and Transformers.js workflows.
+- **Terminal-heavy developers and infra teams** get stronger `jq` and `tmux` guidance for JSON processing, session management, and scripted shell workflows.
+- **Odoo integrators** get safer bridge examples and more production-ready EDI connector patterns.
+- **Builders looking for collaborator-matching workflows** get a new `viboscope` skill for compatibility-driven discovery.
+
+## Credits
+
+- **[@kostakost2](https://github.com/kostakost2)** for the new `jq` and `tmux` skills in PR #414
+- **[@ivankoriako](https://github.com/ivankoriako)** for the new `viboscope` skill in PR #415
+- **[@Champbreed](https://github.com/Champbreed)** for Odoo security and EDI improvements in PRs #413 and #416
+- **[Hugging Face](https://github.com/huggingface/skills)** for the upstream official skill collection synced into this release
+
+### Changed
+
+- **Risk maintenance workflow**: expanded the legacy `risk:` cleanup flow so maintainers can sync explicit high-confidence `none`, `safe`, `critical`, and `offensive` labels from audit suggestions, including auto-inserting the required `AUTHORIZED USE ONLY` notice when a legacy skill is promoted to `offensive`.
31
README.md
@@ -1,7 +1,7 @@
-<!-- registry-sync: version=9.1.0; skills=1332; stars=28053; updated_at=2026-03-28T15:48:03+00:00 -->
-# 🌌 Antigravity Awesome Skills: 1,332+ Agentic Skills for Claude Code, Gemini CLI, Cursor, Copilot & More
+<!-- registry-sync: version=9.1.0; skills=1340; stars=28053; updated_at=2026-03-28T15:48:03+00:00 -->
+# 🌌 Antigravity Awesome Skills: 1,340+ Agentic Skills for Claude Code, Gemini CLI, Cursor, Copilot & More
 
-> **Installable GitHub library of 1,332+ agentic skills for Claude Code, Cursor, Codex CLI, Gemini CLI, Antigravity, and other AI coding assistants.**
+> **Installable GitHub library of 1,340+ agentic skills for Claude Code, Cursor, Codex CLI, Gemini CLI, Antigravity, and other AI coding assistants.**
 
 Antigravity Awesome Skills is a GitHub repository and installer CLI for reusable `SKILL.md` playbooks. Instead of collecting random prompts, you get a searchable, installable skill library for planning, coding, debugging, testing, security review, infrastructure work, product workflows, and growth tasks across the major AI coding assistants.
@@ -26,7 +26,7 @@ Antigravity Awesome Skills is a GitHub repository and installer CLI for reusable
 
 - **Installable, not just inspirational**: use `npx antigravity-awesome-skills` to put skills where your tool expects them.
 - **Built for major agent workflows**: Claude Code, Cursor, Codex CLI, Gemini CLI, Antigravity, Kiro, OpenCode, Copilot, and more.
-- **Broad coverage with real utility**: 1,332+ skills across development, testing, security, infrastructure, product, and marketing.
+- **Broad coverage with real utility**: 1,340+ skills across development, testing, security, infrastructure, product, and marketing.
 - **Faster onboarding**: bundles and workflows reduce the time from "I found this repo" to "I used my first skill".
 - **Useful whether you want breadth or curation**: browse the full catalog, start with top bundles, or compare alternatives before installing.
@@ -47,7 +47,7 @@ Antigravity Awesome Skills is a GitHub repository and installer CLI for reusable
 - [🧭 Antigravity Workflows](#antigravity-workflows)
 - [⚖️ Alternatives & Comparisons](#alternatives--comparisons)
 - [📦 Features & Categories](#features--categories)
-- [📚 Browse 1,332+ Skills](#browse-1332-skills)
+- [📚 Browse 1,340+ Skills](#browse-1340-skills)
 - [🤝 Contributing](#contributing)
 - [💬 Community](#community)
 - [☕ Support the Project](#support-the-project)
@@ -261,7 +261,7 @@ If you want the full explanation of root plugins, bundle plugins, full-library i
|
||||
|
||||
## Best Skills By Tool

If you want a faster answer than "browse all 1,332+ skills", start with a tool-specific guide:
If you want a faster answer than "browse all 1,340+ skills", start with a tool-specific guide:

- **[Claude Code skills](docs/users/claude-code-skills.md)**: install paths, starter skills, prompt examples, and plugin marketplace flow.
- **[Cursor skills](docs/users/cursor-skills.md)**: best starter skills for `.cursor/skills/`, UI-heavy work, and pair-programming flows.
@@ -428,7 +428,7 @@ The repository is organized into specialized domains to transform your AI into a

Counts change as new skills are added. For the current full registry, see [CATALOG.md](CATALOG.md).

## Browse 1,332+ Skills
## Browse 1,340+ Skills

- Open the interactive browser in [`apps/web-app`](apps/web-app).
- Read the full catalog in [`CATALOG.md`](CATALOG.md).
@@ -651,12 +651,12 @@ We officially thank the following contributors for their help in making this rep
- [@Mohammad-Faiz-Cloud-Engineer](https://github.com/Mohammad-Faiz-Cloud-Engineer)
- [@zinzied](https://github.com/zinzied)
- [@ssumanbiswas](https://github.com/ssumanbiswas)
- [@Champbreed](https://github.com/Champbreed)
- [@Dokhacgiakhoa](https://github.com/Dokhacgiakhoa)
- [@sx4im](https://github.com/sx4im)
- [@maxdml](https://github.com/maxdml)
- [@IanJ332](https://github.com/IanJ332)
- [@maxdml](https://github.com/maxdml)
- [@skyruh](https://github.com/skyruh)
- [@Champbreed](https://github.com/Champbreed)
- [@ar27111994](https://github.com/ar27111994)
- [@chauey](https://github.com/chauey)
- [@itsmeares](https://github.com/itsmeares)
@@ -681,13 +681,13 @@ We officially thank the following contributors for their help in making this rep
- [@fernandezbaptiste](https://github.com/fernandezbaptiste)
- [@Gizzant](https://github.com/Gizzant)
- [@JayeHarrill](https://github.com/JayeHarrill)
- [@AssassinMaeve](https://github.com/AssassinMaeve)
- [@Tiger-Foxx](https://github.com/Tiger-Foxx)
- [@RamonRiosJr](https://github.com/RamonRiosJr)
- [@Musayrlsms](https://github.com/Musayrlsms)
- [@AssassinMaeve](https://github.com/AssassinMaeve)
- [@vuth-dogo](https://github.com/vuth-dogo)
- [@Wittlesus](https://github.com/Wittlesus)
- [@wahidzzz](https://github.com/wahidzzz)
- [@yubing744](https://github.com/yubing744)
- [@Vonfry](https://github.com/Vonfry)
- [@vprudnikoff](https://github.com/vprudnikoff)
- [@viktor-ferenczi](https://github.com/viktor-ferenczi)
@@ -699,7 +699,6 @@ We officially thank the following contributors for their help in making this rep
- [@TomGranot](https://github.com/TomGranot)
- [@terryspitz](https://github.com/terryspitz)
- [@Onsraa](https://github.com/Onsraa)
- [@PabloASMD](https://github.com/PabloASMD)
- [@SebConejo](https://github.com/SebConejo)
- [@SuperJMN](https://github.com/SuperJMN)
- [@Enreign](https://github.com/Enreign)
@@ -710,7 +709,8 @@ We officially thank the following contributors for their help in making this rep
- [@ronanguilloux](https://github.com/ronanguilloux)
- [@sraphaz](https://github.com/sraphaz)
- [@ProgramadorBrasil](https://github.com/ProgramadorBrasil)
- [@vuth-dogo](https://github.com/vuth-dogo)
- [@PabloASMD](https://github.com/PabloASMD)
- [@yubing744](https://github.com/yubing744)
- [@yang1002378395-cmyk](https://github.com/yang1002378395-cmyk)
- [@viliawang-pm](https://github.com/viliawang-pm)
- [@uucz](https://github.com/uucz)
@@ -737,8 +737,7 @@ We officially thank the following contributors for their help in making this rep
- [@ziuus](https://github.com/ziuus)
- [@Cerdore](https://github.com/Cerdore)
- [@Wolfe-Jam](https://github.com/Wolfe-Jam)
- [@olgasafonova](https://github.com/olgasafonova)
- [@ivankoriako](https://github.com/ivankoriako)
- [@qcwssss](https://github.com/qcwssss)
- [@rcigor](https://github.com/rcigor)
- [@hvasconcelos](https://github.com/hvasconcelos)
- [@Guilherme-ruy](https://github.com/Guilherme-ruy)
@@ -764,6 +763,7 @@ We officially thank the following contributors for their help in making this rep
- [@ALEKGG1](https://github.com/ALEKGG1)
- [@8144225309](https://github.com/8144225309)
- [@1bcMax](https://github.com/1bcMax)
- [@olgasafonova](https://github.com/olgasafonova)
- [@sharmanilay](https://github.com/sharmanilay)
- [@KhaiTrang1995](https://github.com/KhaiTrang1995)
- [@LocNguyenSGU](https://github.com/LocNguyenSGU)
@@ -789,7 +789,6 @@ We officially thank the following contributors for their help in making this rep
- [@Jonohobs](https://github.com/Jonohobs)
- [@JaskiratAnand](https://github.com/JaskiratAnand)
- [@jamescha-earley](https://github.com/jamescha-earley)
- [@qcwssss](https://github.com/qcwssss)

## License

@@ -12,6 +12,12 @@
<changefreq>weekly</changefreq>
<priority>0.7</priority>
</url>
<url>
<loc>http://localhost/skill/jq</loc>
<lastmod>2026-03-29</lastmod>
<changefreq>weekly</changefreq>
<priority>0.7</priority>
</url>
<url>
<loc>http://localhost/skill/phase-gated-debugging</loc>
<lastmod>2026-03-29</lastmod>
@@ -24,6 +30,12 @@
<changefreq>weekly</changefreq>
<priority>0.7</priority>
</url>
<url>
<loc>http://localhost/skill/tmux</loc>
<lastmod>2026-03-29</lastmod>
<changefreq>weekly</changefreq>
<priority>0.7</priority>
</url>
<url>
<loc>http://localhost/skill/akf-trust-metadata</loc>
<lastmod>2026-03-29</lastmod>
@@ -234,16 +246,4 @@
<changefreq>weekly</changefreq>
<priority>0.7</priority>
</url>
<url>
<loc>http://localhost/skill/obsidian-markdown</loc>
<lastmod>2026-03-29</lastmod>
<changefreq>weekly</changefreq>
<priority>0.7</priority>
</url>
<url>
<loc>http://localhost/skill/product-marketing-context</loc>
<lastmod>2026-03-29</lastmod>
<changefreq>weekly</changefreq>
<priority>0.7</priority>
</url>
</urlset>

@@ -14598,10 +14598,32 @@
"path": "skills/hugging-face-cli",
"category": "ai-ml",
"name": "hugging-face-cli",
"description": "The hf CLI provides direct terminal access to the Hugging Face Hub for downloading, uploading, and managing repositories, cache, and compute resources.",
"risk": "safe",
"source": "https://github.com/huggingface/skills/tree/main/skills/hugging-face-cli",
"date_added": "2026-02-27",
"description": "Use the Hugging Face Hub CLI (`hf`) to download, upload, and manage models, datasets, and Spaces.",
"risk": "unknown",
"source": "https://github.com/huggingface/skills/tree/main/skills/hf-cli",
"date_added": null,
"plugin": {
"targets": {
"codex": "supported",
"claude": "supported"
},
"setup": {
"type": "none",
"summary": "",
"docs": null
},
"reasons": []
}
},
{
"id": "hugging-face-community-evals",
"path": "skills/hugging-face-community-evals",
"category": "ai-ml",
"name": "hugging-face-community-evals",
"description": "Run local evaluations for Hugging Face Hub models with inspect-ai or lighteval.",
"risk": "unknown",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-community-evals",
"date_added": null,
"plugin": {
"targets": {
"codex": "supported",
@@ -14620,9 +14642,9 @@
"path": "skills/hugging-face-dataset-viewer",
"category": "ai-ml",
"name": "hugging-face-dataset-viewer",
"description": "Use this skill for Hugging Face Dataset Viewer API workflows that fetch subset/split metadata, paginate rows, search text, apply filters, download parquet URLs, and read size or statistics.",
"description": "Query Hugging Face datasets through the Dataset Viewer API for splits, rows, search, filters, and parquet links.",
"risk": "unknown",
"source": "community",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-datasets",
"date_added": null,
"plugin": {
"targets": {
@@ -14681,15 +14703,37 @@
"reasons": []
}
},
{
"id": "hugging-face-gradio",
"path": "skills/hugging-face-gradio",
"category": "ai-ml",
"name": "hugging-face-gradio",
"description": "Build or edit Gradio apps, layouts, components, and chat interfaces in Python.",
"risk": "unknown",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-gradio",
"date_added": null,
"plugin": {
"targets": {
"codex": "supported",
"claude": "supported"
},
"setup": {
"type": "none",
"summary": "",
"docs": null
},
"reasons": []
}
},
{
"id": "hugging-face-jobs",
"path": "skills/hugging-face-jobs",
"category": "ai-ml",
"name": "hugging-face-jobs",
"description": "Run any workload on fully managed Hugging Face infrastructure. No local setup required\u2014jobs run on cloud CPUs, GPUs, or TPUs and can persist results to the Hugging Face Hub.",
"risk": "safe",
"source": "https://github.com/huggingface/skills/tree/main/skills/hugging-face-jobs",
"date_added": "2026-02-27",
"description": "Run workloads on Hugging Face Jobs with managed CPUs, GPUs, TPUs, secrets, and Hub persistence.",
"risk": "unknown",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-jobs",
"date_added": null,
"plugin": {
"targets": {
"codex": "supported",
@@ -14708,9 +14752,9 @@
"path": "skills/hugging-face-model-trainer",
"category": "ai-ml",
"name": "hugging-face-model-trainer",
"description": "Train language models using TRL (Transformer Reinforcement Learning) on fully managed Hugging Face infrastructure. No local GPU setup required\u2014models train on cloud GPUs and results are automatically saved to the Hugging Face Hub.",
"description": "Train or fine-tune TRL language models on Hugging Face Jobs, including SFT, DPO, GRPO, and GGUF export.",
"risk": "unknown",
"source": "community",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-llm-trainer",
"date_added": null,
"plugin": {
"targets": {
@@ -14732,7 +14776,29 @@
"name": "hugging-face-paper-publisher",
"description": "Publish and manage research papers on Hugging Face Hub. Supports creating paper pages, linking papers to models/datasets, claiming authorship, and generating professional markdown-based research articles.",
"risk": "unknown",
"source": "community",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-paper-publisher",
"date_added": null,
"plugin": {
"targets": {
"codex": "supported",
"claude": "supported"
},
"setup": {
"type": "none",
"summary": "",
"docs": null
},
"reasons": []
}
},
{
"id": "hugging-face-papers",
"path": "skills/hugging-face-papers",
"category": "ai-ml",
"name": "hugging-face-papers",
"description": "Read and analyze Hugging Face paper pages or arXiv papers with markdown and papers API metadata.",
"risk": "unknown",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-papers",
"date_added": null,
"plugin": {
"targets": {
@@ -14769,6 +14835,50 @@
"reasons": []
}
},
{
"id": "hugging-face-trackio",
"path": "skills/hugging-face-trackio",
"category": "ai-ml",
"name": "hugging-face-trackio",
"description": "Track ML experiments with Trackio using Python logging, alerts, and CLI metric retrieval.",
"risk": "unknown",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-trackio",
"date_added": null,
"plugin": {
"targets": {
"codex": "supported",
"claude": "supported"
},
"setup": {
"type": "none",
"summary": "",
"docs": null
},
"reasons": []
}
},
{
"id": "hugging-face-vision-trainer",
"path": "skills/hugging-face-vision-trainer",
"category": "ai-ml",
"name": "hugging-face-vision-trainer",
"description": "Train or fine-tune vision models on Hugging Face Jobs for detection, classification, and SAM or SAM2 segmentation.",
"risk": "unknown",
"source": "https://github.com/huggingface/skills/tree/main/skills/huggingface-vision-trainer",
"date_added": null,
"plugin": {
"targets": {
"codex": "supported",
"claude": "supported"
},
"setup": {
"type": "none",
"summary": "",
"docs": null
},
"reasons": []
}
},
{
"id": "hybrid-cloud-architect",
"path": "skills/hybrid-cloud-architect",
@@ -15583,6 +15693,28 @@
"reasons": []
}
},
{
"id": "jq",
"path": "skills/jq",
"category": "development",
"name": "jq",
"description": "Expert jq usage for JSON querying, filtering, transformation, and pipeline integration. Practical patterns for real shell workflows.",
"risk": "safe",
"source": "community",
"date_added": "2026-03-28",
"plugin": {
"targets": {
"codex": "supported",
"claude": "supported"
},
"setup": {
"type": "none",
"summary": "",
"docs": null
},
"reasons": []
}
},
{
"id": "json-canvas",
"path": "skills/json-canvas",
@@ -26897,6 +27029,28 @@
"reasons": []
}
},
{
"id": "tmux",
"path": "skills/tmux",
"category": "development",
"name": "tmux",
"description": "Expert tmux session, window, and pane management for terminal multiplexing, persistent remote workflows, and shell scripting automation.",
"risk": "safe",
"source": "community",
"date_added": "2026-03-28",
"plugin": {
"targets": {
"codex": "supported",
"claude": "supported"
},
"setup": {
"type": "none",
"summary": "",
"docs": null
},
"reasons": []
}
},
{
"id": "todoist-automation",
"path": "skills/todoist-automation",
@@ -27007,6 +27161,28 @@
"reasons": []
}
},
{
"id": "transformers-js",
"path": "skills/transformers-js",
"category": "web-development",
"name": "transformers-js",
"description": "Run Hugging Face models in JavaScript or TypeScript with Transformers.js in Node.js or the browser.",
"risk": "unknown",
"source": "https://github.com/huggingface/skills/tree/main/skills/transformers-js",
"date_added": null,
"plugin": {
"targets": {
"codex": "supported",
"claude": "supported"
},
"setup": {
"type": "none",
"summary": "",
"docs": null
},
"reasons": []
}
},
{
"id": "travel-health-analyzer",
"path": "skills/travel-health-analyzer",

@@ -80,6 +80,7 @@
"vr-ar": "game-development/vr-ar",
"web-games": "game-development/web-games",
"git-pr-workflow": "git-pr-workflows-git-workflow",
"hugging-face-evals": "hugging-face-community-evals",
"hugging-face-publisher": "hugging-face-paper-publisher",
"incident-response": "incident-response-incident-response",
"javascript-typescript-scaffold": "javascript-typescript-typescript-scaffold",

@@ -183,7 +183,10 @@
"hono",
"hugging-face-dataset-viewer",
"hugging-face-evaluation",
"hugging-face-gradio",
"hugging-face-papers",
"hugging-face-tool-builder",
"hugging-face-trackio",
"instagram",
"ios-debugger-agent",
"ios-developer",
@@ -288,6 +291,7 @@
"temporal-golang-pro",
"temporal-python-pro",
"temporal-python-testing",
"transformers-js",
"trigger-dev",
"trpc-fullstack",
"typescript-advanced-types",
@@ -371,6 +375,7 @@
"gha-security-review",
"graphql-architect",
"html-injection-testing",
"hugging-face-jobs",
"k8s-manifest-generator",
"k8s-security-policies",
"laravel-expert",
@@ -691,11 +696,13 @@
"gitops-workflow",
"grafana-dashboards",
"grpc-golang",
"hugging-face-trackio",
"incident-responder",
"incident-response-incident-response",
"incident-response-smart-fix",
"incident-runbook-templates",
"internal-comms",
"jq",
"kubernetes-architect",
"kubernetes-deployment",
"langfuse",
@@ -906,6 +913,7 @@
"testing-patterns",
"testing-qa",
"tiktok-automation",
"tmux",
"todoist-automation",
"trello-automation",
"twitter-automation",

@@ -1,6 +1,6 @@
{
"generatedAt": "2026-02-08T00:00:00.000Z",
"total": 1332,
"total": 1340,
"skills": [
{
"id": "00-andruia-consultant",
@@ -16279,7 +16279,7 @@
{
"id": "hugging-face-cli",
"name": "hugging-face-cli",
"description": "The hf CLI provides direct terminal access to the Hugging Face Hub for downloading, uploading, and managing repositories, cache, and compute resources.",
"description": "Use the Hugging Face Hub CLI (`hf`) to download, upload, and manage models, datasets, and Spaces.",
"category": "general",
"tags": [
"hugging",
@@ -16290,22 +16290,47 @@
"hugging",
"face",
"cli",
"hf",
"provides",
"direct",
"terminal",
"access",
"hub",
"downloading",
"uploading",
"managing"
"hf",
"download",
"upload",
"models",
"datasets",
"spaces"
],
"path": "skills/hugging-face-cli/SKILL.md"
},
{
"id": "hugging-face-community-evals",
"name": "hugging-face-community-evals",
"description": "Run local evaluations for Hugging Face Hub models with inspect-ai or lighteval.",
"category": "data-ai",
"tags": [
"hugging",
"face",
"community",
"evals"
],
"triggers": [
"hugging",
"face",
"community",
"evals",
"run",
"local",
"evaluations",
"hub",
"models",
"inspect",
"ai",
"lighteval"
],
"path": "skills/hugging-face-community-evals/SKILL.md"
},
{
"id": "hugging-face-dataset-viewer",
"name": "hugging-face-dataset-viewer",
"description": "Use this skill for Hugging Face Dataset Viewer API workflows that fetch subset/split metadata, paginate rows, search text, apply filters, download parquet URLs, and read size or statistics.",
"description": "Query Hugging Face datasets through the Dataset Viewer API for splits, rows, search, filters, and parquet links.",
"category": "development",
"tags": [
"hugging",
@@ -16318,14 +16343,14 @@
"face",
"dataset",
"viewer",
"skill",
"query",
"datasets",
"through",
"api",
"fetch",
"subset",
"split",
"metadata",
"paginate",
"rows"
"splits",
"rows",
"search",
"filters"
],
"path": "skills/hugging-face-dataset-viewer/SKILL.md"
},
@@ -16381,11 +16406,35 @@
],
"path": "skills/hugging-face-evaluation/SKILL.md"
},
{
"id": "hugging-face-gradio",
"name": "hugging-face-gradio",
"description": "Build or edit Gradio apps, layouts, components, and chat interfaces in Python.",
"category": "development",
"tags": [
"hugging",
"face",
"gradio"
],
"triggers": [
"hugging",
"face",
"gradio",
"edit",
"apps",
"layouts",
"components",
"chat",
"interfaces",
"python"
],
"path": "skills/hugging-face-gradio/SKILL.md"
},
{
"id": "hugging-face-jobs",
"name": "hugging-face-jobs",
"description": "Run any workload on fully managed Hugging Face infrastructure. No local setup required—jobs run on cloud CPUs, GPUs, or TPUs and can persist results to the Hugging Face Hub.",
"category": "infrastructure",
"description": "Run workloads on Hugging Face Jobs with managed CPUs, GPUs, TPUs, secrets, and Hub persistence.",
"category": "security",
"tags": [
"hugging",
"face",
@@ -16396,22 +16445,22 @@
"face",
"jobs",
"run",
"any",
"workload",
"fully",
"workloads",
"managed",
"infrastructure",
"no",
"local",
"setup"
"cpus",
"gpus",
"tpus",
"secrets",
"hub",
"persistence"
],
"path": "skills/hugging-face-jobs/SKILL.md"
},
{
"id": "hugging-face-model-trainer",
"name": "hugging-face-model-trainer",
"description": "Train language models using TRL (Transformer Reinforcement Learning) on fully managed Hugging Face infrastructure. No local GPU setup required—models train on cloud GPUs and results are automatically saved to the Hugging Face Hub.",
"category": "infrastructure",
"description": "Train or fine-tune TRL language models on Hugging Face Jobs, including SFT, DPO, GRPO, and GGUF export.",
"category": "general",
"tags": [
"hugging",
"face",
@@ -16424,13 +16473,13 @@
"model",
"trainer",
"train",
"fine",
"tune",
"trl",
"language",
"models",
"trl",
"transformer",
"reinforcement",
"learning",
"fully"
"jobs",
"including"
],
"path": "skills/hugging-face-model-trainer/SKILL.md"
},
@@ -16461,6 +16510,31 @@
],
"path": "skills/hugging-face-paper-publisher/SKILL.md"
},
{
"id": "hugging-face-papers",
"name": "hugging-face-papers",
"description": "Read and analyze Hugging Face paper pages or arXiv papers with markdown and papers API metadata.",
"category": "development",
"tags": [
"hugging",
"face",
"papers"
],
"triggers": [
"hugging",
"face",
"papers",
"read",
"analyze",
"paper",
"pages",
"arxiv",
"markdown",
"api",
"metadata"
],
"path": "skills/hugging-face-papers/SKILL.md"
},
{
"id": "hugging-face-tool-builder",
"name": "hugging-face-tool-builder",
@@ -16487,6 +16561,59 @@
],
"path": "skills/hugging-face-tool-builder/SKILL.md"
},
{
"id": "hugging-face-trackio",
"name": "hugging-face-trackio",
"description": "Track ML experiments with Trackio using Python logging, alerts, and CLI metric retrieval.",
"category": "infrastructure",
"tags": [
"hugging",
"face",
"trackio"
],
"triggers": [
"hugging",
"face",
"trackio",
"track",
"ml",
"experiments",
"python",
"logging",
"alerts",
"cli",
"metric",
"retrieval"
],
"path": "skills/hugging-face-trackio/SKILL.md"
},
{
"id": "hugging-face-vision-trainer",
"name": "hugging-face-vision-trainer",
"description": "Train or fine-tune vision models on Hugging Face Jobs for detection, classification, and SAM or SAM2 segmentation.",
"category": "general",
"tags": [
"hugging",
"face",
"vision",
"trainer"
],
"triggers": [
"hugging",
"face",
"vision",
"trainer",
"train",
"fine",
"tune",
"models",
"jobs",
"detection",
"classification",
"sam"
],
"path": "skills/hugging-face-vision-trainer/SKILL.md"
},
{
"id": "hybrid-cloud-architect",
"name": "hybrid-cloud-architect",
@@ -17339,6 +17466,35 @@
],
"path": "skills/jobgpt/SKILL.md"
},
{
"id": "jq",
"name": "jq",
"description": "Expert jq usage for JSON querying, filtering, transformation, and pipeline integration. Practical patterns for real shell workflows.",
"category": "infrastructure",
"tags": [
"jq",
"json",
"shell",
"cli",
"data-transformation",
"bash"
],
"triggers": [
"jq",
"json",
"shell",
"cli",
"data-transformation",
"bash",
"usage",
"querying",
"filtering",
"transformation",
"pipeline",
"integration"
],
"path": "skills/jq/SKILL.md"
},
{
"id": "json-canvas",
"name": "json-canvas",
@@ -29983,6 +30139,36 @@
],
"path": "skills/tiktok-automation/SKILL.md"
},
{
"id": "tmux",
"name": "tmux",
"description": "Expert tmux session, window, and pane management for terminal multiplexing, persistent remote workflows, and shell scripting automation.",
"category": "workflow",
"tags": [
"tmux",
"terminal",
"multiplexer",
"sessions",
"shell",
"remote",
"automation"
],
"triggers": [
"tmux",
"terminal",
"multiplexer",
"sessions",
"shell",
"remote",
"automation",
"session",
"window",
"pane",
"multiplexing",
"persistent"
],
"path": "skills/tmux/SKILL.md"
},
{
"id": "todoist-automation",
"name": "todoist-automation",
@@ -30108,6 +30294,29 @@
],
"path": "skills/track-management/SKILL.md"
},
{
"id": "transformers-js",
"name": "transformers-js",
"description": "Run Hugging Face models in JavaScript or TypeScript with Transformers.js in Node.js or the browser.",
"category": "development",
"tags": [
"transformers",
"js"
],
"triggers": [
"transformers",
"js",
"run",
"hugging",
"face",
"models",
"javascript",
"typescript",
"node",
"browser"
],
"path": "skills/transformers-js/SKILL.md"
},
{
"id": "travel-health-analyzer",
"name": "travel-health-analyzer",

@@ -12656,6 +12656,25 @@
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hugging-face-community-evals",
|
||||
"path": "skills/hugging-face-community-evals",
|
||||
"targets": {
|
||||
"codex": "supported",
|
||||
"claude": "supported"
|
||||
},
|
||||
"setup": {
|
||||
"type": "none",
|
||||
"summary": "",
|
||||
"docs": null
|
||||
},
|
||||
"reasons": [],
|
||||
"blocked_reasons": {
|
||||
"codex": [],
|
||||
"claude": []
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hugging-face-dataset-viewer",
|
||||
"path": "skills/hugging-face-dataset-viewer",
|
||||
@@ -12713,6 +12732,25 @@
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hugging-face-gradio",
|
||||
"path": "skills/hugging-face-gradio",
|
||||
"targets": {
|
||||
"codex": "supported",
|
||||
"claude": "supported"
|
||||
},
|
||||
"setup": {
|
||||
"type": "none",
|
||||
"summary": "",
|
||||
"docs": null
|
||||
},
|
||||
"reasons": [],
|
||||
"blocked_reasons": {
|
||||
"codex": [],
|
||||
"claude": []
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hugging-face-jobs",
|
||||
"path": "skills/hugging-face-jobs",
|
||||
@@ -12770,6 +12808,25 @@
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hugging-face-papers",
|
||||
"path": "skills/hugging-face-papers",
|
||||
"targets": {
|
||||
"codex": "supported",
|
||||
"claude": "supported"
|
||||
},
|
||||
"setup": {
|
||||
"type": "none",
|
||||
"summary": "",
|
||||
"docs": null
|
||||
},
|
||||
"reasons": [],
|
||||
"blocked_reasons": {
|
||||
"codex": [],
|
||||
"claude": []
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hugging-face-tool-builder",
|
||||
"path": "skills/hugging-face-tool-builder",
|
||||
@@ -12789,6 +12846,44 @@
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hugging-face-trackio",
|
||||
"path": "skills/hugging-face-trackio",
|
||||
"targets": {
|
||||
"codex": "supported",
|
||||
"claude": "supported"
|
||||
},
|
||||
"setup": {
|
||||
"type": "none",
|
||||
"summary": "",
|
||||
"docs": null
|
||||
},
|
||||
"reasons": [],
|
||||
"blocked_reasons": {
|
||||
"codex": [],
|
||||
"claude": []
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hugging-face-vision-trainer",
|
||||
"path": "skills/hugging-face-vision-trainer",
|
||||
"targets": {
|
||||
"codex": "supported",
|
||||
"claude": "supported"
|
||||
},
|
||||
"setup": {
|
||||
"type": "none",
|
||||
"summary": "",
|
||||
"docs": null
|
||||
},
|
||||
"reasons": [],
|
||||
"blocked_reasons": {
|
||||
"codex": [],
|
||||
"claude": []
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "hybrid-cloud-architect",
|
||||
"path": "skills/hybrid-cloud-architect",
|
||||
@@ -13473,6 +13568,25 @@
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "jq",
|
||||
"path": "skills/jq",
|
||||
"targets": {
|
||||
"codex": "supported",
|
||||
"claude": "supported"
|
||||
},
|
||||
"setup": {
|
||||
"type": "none",
|
||||
"summary": "",
|
||||
"docs": null
|
||||
},
|
||||
"reasons": [],
|
||||
"blocked_reasons": {
|
||||
"codex": [],
|
||||
"claude": []
|
||||
},
|
||||
"runtime_files": []
|
||||
},
|
||||
{
|
||||
"id": "json-canvas",
|
||||
"path": "skills/json-canvas",
|
||||
@@ -23387,6 +23501,25 @@
      },
      "runtime_files": []
    },
    {
      "id": "tmux",
      "path": "skills/tmux",
      "targets": {
        "codex": "supported",
        "claude": "supported"
      },
      "setup": {
        "type": "none",
        "summary": "",
        "docs": null
      },
      "reasons": [],
      "blocked_reasons": {
        "codex": [],
        "claude": []
      },
      "runtime_files": []
    },
    {
      "id": "todoist-automation",
      "path": "skills/todoist-automation",
@@ -23482,6 +23615,25 @@
      },
      "runtime_files": []
    },
    {
      "id": "transformers-js",
      "path": "skills/transformers-js",
      "targets": {
        "codex": "supported",
        "claude": "supported"
      },
      "setup": {
        "type": "none",
        "summary": "",
        "docs": null
      },
      "reasons": [],
      "blocked_reasons": {
        "codex": [],
        "claude": []
      },
      "runtime_files": []
    },
    {
      "id": "travel-health-analyzer",
      "path": "skills/travel-health-analyzer",
@@ -25470,10 +25622,10 @@
    }
  ],
  "summary": {
    "total_skills": 1332,
    "total_skills": 1340,
    "supported": {
      "codex": 1303,
      "claude": 1318
      "codex": 1311,
      "claude": 1326
    },
    "blocked": {
      "codex": 29,
@@ -1,9 +1,9 @@
---
title: Jetski/Cortex + Gemini Integration Guide
description: "How to use antigravity-awesome-skills with Jetski/Cortex while avoiding context overflow with 1,332+ skills."
description: "How to use antigravity-awesome-skills with Jetski/Cortex while avoiding context overflow with 1,340+ skills."
---

# Jetski/Cortex + Gemini: safe integration with 1,332+ skills
# Jetski/Cortex + Gemini: safe integration with 1,340+ skills

This guide shows how to integrate the `antigravity-awesome-skills` repository with an agent based on **Jetski/Cortex + Gemini** (or similar frameworks) **without exceeding the model's context window**.

@@ -23,7 +23,7 @@ You must never:
- concatenate the content of every `SKILL.md` into a single system prompt;
- re-inject the entire library on **every** request.

With more than 1,332 skills, this approach fills the context window before the user's messages are even added, triggering the truncation error.
With more than 1,340 skills, this approach fills the context window before the user's messages are even added, triggering the truncation error.

---

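The safe alternative is lazy, per-request selection: keep only a lightweight index in memory and inject at most a handful of matching skill bodies per turn. A minimal sketch, where the index shape and scoring are illustrative rather than part of the repo:

```python
def select_skills(index, query, max_skills=3):
    """Rank skills by naive keyword overlap and return at most max_skills ids."""
    words = set(query.lower().split())
    scored = []
    for skill_id, triggers in index.items():
        score = len(words & set(triggers))
        if score:
            scored.append((score, skill_id))
    scored.sort(reverse=True)
    return [skill_id for _, skill_id in scored[:max_skills]]

# Hypothetical index: skill id -> trigger keywords
index = {
    "hugging-face-cli": ["hugging", "face", "cli", "download", "upload"],
    "tmux": ["tmux", "terminal", "session"],
    "jq": ["jq", "json", "filter"],
}
select_skills(index, "download a model with the hugging face cli")
```

Only the bodies of the returned ids are then read from disk and injected, keeping the prompt size bounded regardless of how many skills are installed.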
@@ -20,7 +20,7 @@ This example shows one way to integrate **antigravity-awesome-skills** with a Je
- How to enforce a **maximum number of skills per turn** via `maxSkillsPerTurn`.
- How to choose whether to **truncate or error** when too many skills are requested via `overflowBehavior`.

This pattern avoids context overflow when you have 1,332+ skills installed.
This pattern avoids context overflow when you have 1,340+ skills installed.

---

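The two knobs described above can be sketched as follows; the option names `maxSkillsPerTurn` and `overflowBehavior` come from the example configuration, while the function itself is an illustrative stand-in for the loader:

```python
def enforce_limit(requested, max_skills_per_turn=5, overflow_behavior="truncate"):
    """Apply the per-turn skill cap, either truncating or raising an error."""
    if len(requested) <= max_skills_per_turn:
        return requested
    if overflow_behavior == "truncate":
        return requested[:max_skills_per_turn]
    raise ValueError(
        f"{len(requested)} skills requested, limit is {max_skills_per_turn}"
    )

enforce_limit(["a", "b", "c"], max_skills_per_turn=2)  # → ["a", "b"]
```

`truncate` keeps the turn alive with the highest-priority skills; `error` surfaces the overflow to the caller so it can re-plan instead of silently dropping context.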
@@ -6,7 +6,7 @@ This document keeps the repository's GitHub-facing discovery copy aligned with t
Preferred positioning:

> Installable GitHub library of 1,332+ agentic skills for Claude Code, Cursor, Codex CLI, Gemini CLI, Antigravity, and other AI coding assistants.
> Installable GitHub library of 1,340+ agentic skills for Claude Code, Cursor, Codex CLI, Gemini CLI, Antigravity, and other AI coding assistants.

Key framing:

@@ -20,7 +20,7 @@ Key framing:

Preferred description:

> Installable GitHub library of 1,332+ agentic skills for Claude Code, Cursor, Codex CLI, Gemini CLI, Antigravity, and more. Includes installer CLI, bundles, workflows, and official/community skill collections.
> Installable GitHub library of 1,340+ agentic skills for Claude Code, Cursor, Codex CLI, Gemini CLI, Antigravity, and more. Includes installer CLI, bundles, workflows, and official/community skill collections.

Preferred homepage:

@@ -28,7 +28,7 @@ Preferred homepage:

Preferred social preview:

- use a clean preview image that says `1,332+ Agentic Skills`;
- use a clean preview image that says `1,340+ Agentic Skills`;
- mention Claude Code, Cursor, Codex CLI, and Gemini CLI;
- avoid dense text and tiny logos that disappear in social cards.

@@ -69,7 +69,7 @@ For manual updates, you need:
The update process refreshes:
- Skills index (`skills_index.json`)
- Web app skills data (`apps/web-app/public/skills.json`)
- All 1,332+ skills from the skills directory
- All 1,340+ skills from the skills directory

## When to Update

@@ -24,7 +24,7 @@ The following skills were added from the curated collection at [VoltAgent/awesom
| :------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------ | :--------- | :--------------------------------- |
| `vercel-deploy-claimable` | [Vercel Labs](https://github.com/vercel-labs/agent-skills) | MIT | Official Vercel skill |
| `design-md` | [Google Labs (Stitch)](https://github.com/google-labs-code/stitch-skills) | Compatible | Google Labs Stitch skills |
| `hugging-face-cli`, `hugging-face-jobs` | [Hugging Face](https://github.com/huggingface/skills) | Compatible | Official Hugging Face skills |
| `hugging-face-cli`, `hugging-face-community-evals`, `hugging-face-dataset-viewer`, `hugging-face-gradio`, `hugging-face-jobs`, `hugging-face-model-trainer`, `hugging-face-paper-publisher`, `hugging-face-papers`, `hugging-face-trackio`, `hugging-face-vision-trainer`, `transformers-js` | [Hugging Face](https://github.com/huggingface/skills) | Compatible | Official Hugging Face skills |
| `culture-index`, `fix-review`, `sharp-edges` | [Trail of Bits](https://github.com/trailofbits/skills) | Compatible | Security skills from Trail of Bits |
| `expo-deployment`, `upgrading-expo` | [Expo](https://github.com/expo/skills) | Compatible | Official Expo skills |
| `commit`, `create-pr`, `find-bugs`, `iterate-pr` | [Sentry](https://github.com/getsentry/skills) | Compatible | Sentry dev team skills |

@@ -118,7 +118,7 @@ The following skills were added during the March 2026 skills update:

### Machine Learning & Data

| Skill | Source | License | Notes |
|-------|--------|---------|-------|
| `hugging-face-dataset-viewer`, `hugging-face-datasets`, `hugging-face-evaluation`, `hugging-face-model-trainer`, `hugging-face-paper-publisher`, `hugging-face-tool-builder` | [huggingface/skills](https://github.com/huggingface/skills) | Compatible | HuggingFace ML tools |
| `hugging-face-datasets`, `hugging-face-evaluation`, `hugging-face-tool-builder` | [huggingface/skills](https://github.com/huggingface/skills) | Compatible | Hugging Face ecosystem extensions |
| `numpy`, `pandas`, `scipy`, `matplotlib`, `scikit-learn`, `jupyter-workflow` | [K-Dense-AI/claude-scientific-skills](https://github.com/K-Dense-AI/claude-scientific-skills) | Compatible | Data science essentials |
| `biopython`, `scanpy`, `uniprot-database`, `pubmed-database` | [K-Dense-AI/claude-scientific-skills](https://github.com/K-Dense-AI/claude-scientific-skills) | Compatible | Bioinformatics tools |

@@ -673,4 +673,4 @@ Found a skill that should be in a bundle? Or want to create a new bundle? [Open
---

_Last updated: March 2026 | Total Skills: 1,332+ | Total Bundles: 37_
_Last updated: March 2026 | Total Skills: 1,340+ | Total Bundles: 37_

@@ -12,7 +12,7 @@ Install the library into Claude Code, then invoke focused skills directly in the
## Why use this repo for Claude Code

- It includes 1,332+ skills instead of a narrow single-domain starter pack.
- It includes 1,340+ skills instead of a narrow single-domain starter pack.
- It supports the standard `.claude/skills/` path and the Claude Code plugin marketplace flow.
- It also ships generated bundle plugins so teams can install focused packs like `Essentials` or `Security Developer` from the marketplace metadata.
- It includes onboarding docs, bundles, and workflows so new users do not need to guess where to begin.

@@ -12,7 +12,7 @@ Install into the Gemini skills path, then ask Gemini to apply one skill at a tim

- It installs directly into the expected Gemini skills path.
- It includes both core software engineering skills and deeper agent/LLM-oriented skills.
- It helps new users get started with bundles and workflows rather than forcing a cold start from 1,332+ files.
- It helps new users get started with bundles and workflows rather than forcing a cold start from 1,340+ files.
- It is useful whether you want a broad internal skill library or a single repo to test many workflows quickly.

## Install Gemini CLI Skills

@@ -18,7 +18,7 @@ Kiro is AWS's agentic AI IDE that combines:
Kiro's agentic capabilities are enhanced by skills that provide:

- **Domain expertise** across 1,332+ specialized areas
- **Domain expertise** across 1,340+ specialized areas
- **Best practices** from Anthropic, OpenAI, Google, Microsoft, and AWS
- **Workflow automation** for common development tasks
- **AWS-specific patterns** for serverless, infrastructure, and cloud architecture

@@ -14,7 +14,7 @@ If you came in through a **Claude Code** or **Codex** plugin instead of a full l
When you ran `npx antigravity-awesome-skills` or cloned the repository, you:

✅ **Downloaded 1,332+ skill files** to your computer (default: `~/.gemini/antigravity/skills/`; or a custom path like `~/.agent/skills/` if you used `--path`)
✅ **Downloaded 1,340+ skill files** to your computer (default: `~/.gemini/antigravity/skills/`; or a custom path like `~/.agent/skills/` if you used `--path`)
✅ **Made them available** to your AI assistant
❌ **Did NOT enable them all automatically** (they're just sitting there, waiting)

@@ -34,7 +34,7 @@ Bundles are **curated groups** of skills organized by role. They help you decide

**Analogy:**

- You installed a toolbox with 1,332+ tools (✅ done)
- You installed a toolbox with 1,340+ tools (✅ done)
- Bundles are like **labeled organizer trays** saying: "If you're a carpenter, start with these 10 tools"
- You can either **pick skills from the tray** or install that tray as a focused marketplace bundle plugin

@@ -212,7 +212,7 @@ Let's actually use a skill right now. Follow these steps:

## Step 5: Picking Your First Skills (Practical Advice)

Don't try to use all 1,332+ skills at once. Here's a sensible approach:
Don't try to use all 1,340+ skills at once. Here's a sensible approach:

If you want a tool-specific starting point before choosing skills, use:

@@ -343,7 +343,7 @@ Usually no, but if your AI doesn't recognize a skill:

### "Can I load all skills into the model at once?"

No. Even though you have 1,332+ skills installed locally, you should **not** concatenate every `SKILL.md` into a single system prompt or context block.
No. Even though you have 1,340+ skills installed locally, you should **not** concatenate every `SKILL.md` into a single system prompt or context block.

The intended pattern is:

@@ -34,7 +34,7 @@ antigravity-awesome-skills/
├── 📄 CONTRIBUTING.md ← Contributor workflow
├── 📄 CATALOG.md ← Full generated catalog
│
├── 📁 skills/ ← 1,332+ skills live here
├── 📁 skills/ ← 1,340+ skills live here
│   │
│   ├── 📁 brainstorming/
│   │   └── 📄 SKILL.md ← Skill definition

@@ -47,7 +47,7 @@ antigravity-awesome-skills/
│   │   └── 📁 2d-games/
│   │       └── 📄 SKILL.md ← Nested skills also supported
│   │
│   └── ... (1,332+ total)
│   └── ... (1,340+ total)
│
├── 📁 apps/
│   └── 📁 web-app/ ← Interactive browser

@@ -100,7 +100,7 @@ antigravity-awesome-skills/

```
┌─────────────────────────┐
│      1,332+ SKILLS      │
│      1,340+ SKILLS      │
└────────────┬────────────┘
             │
┌────────────────────────┼────────────────────────┐

@@ -201,7 +201,7 @@ If you want a workspace-style manual install instead, cloning into `.agent/skill
│ ├── 📁 brainstorming/                   │
│ ├── 📁 stripe-integration/              │
│ ├── 📁 react-best-practices/            │
│ └── ... (1,332+ total)                  │
│ └── ... (1,340+ total)                  │
└─────────────────────────────────────────┘
```

@@ -30,7 +30,7 @@
| :------------------------------------------------------------------------------------------ | :------------------------------------------------------------------------ | :--------- | :--------------------------------- |
| `vercel-deploy-claimable` | [Vercel Labs](https://github.com/vercel-labs/agent-skills) | MIT | Official Vercel skill |
| `design-md` | [Google Labs (Stitch)](https://github.com/google-labs-code/stitch-skills) | Compatible | Google Labs Stitch skills |
| `hugging-face-cli`, `hugging-face-jobs` | [Hugging Face](https://github.com/huggingface/skills) | Compatible | Official Hugging Face skills |
| `hugging-face-cli`, `hugging-face-community-evals`, `hugging-face-dataset-viewer`, `hugging-face-gradio`, `hugging-face-jobs`, `hugging-face-model-trainer`, `hugging-face-paper-publisher`, `hugging-face-papers`, `hugging-face-trackio`, `hugging-face-vision-trainer`, `transformers-js` | [Hugging Face](https://github.com/huggingface/skills) | Compatible | Official Hugging Face skills |
| `culture-index`, `fix-review`, `sharp-edges` | [Trail of Bits](https://github.com/trailofbits/skills) | Compatible | Trail of Bits security skills |
| `expo-deployment`, `upgrading-expo` | [Expo](https://github.com/expo/skills) | Compatible | Official Expo skills |
| `commit`, `create-pr`, `find-bugs`, `iterate-pr` | [Sentry](https://github.com/getsentry/skills) | Compatible | Sentry dev team skills |

@@ -116,7 +116,7 @@

### Machine Learning & Data

| Skill | Source | License | Notes |
|-------|--------|---------|-------|
| `hugging-face-dataset-viewer`, `hugging-face-datasets`, `hugging-face-evaluation`, `hugging-face-model-trainer`, `hugging-face-paper-publisher`, `hugging-face-tool-builder` | [huggingface/skills](https://github.com/huggingface/skills) | Compatible | HuggingFace ML tools |
| `hugging-face-datasets`, `hugging-face-evaluation`, `hugging-face-tool-builder` | [huggingface/skills](https://github.com/huggingface/skills) | Compatible | Hugging Face ecosystem extensions |
| `numpy`, `pandas`, `scipy`, `matplotlib`, `scikit-learn`, `jupyter-workflow` | [K-Dense-AI/claude-scientific-skills](https://github.com/K-Dense-AI/claude-scientific-skills) | Compatible | Data science essentials |
| `biopython`, `scanpy`, `uniprot-database`, `pubmed-database` | [K-Dense-AI/claude-scientific-skills](https://github.com/K-Dense-AI/claude-scientific-skills) | Compatible | Bioinformatics tools |

@@ -234,4 +234,4 @@

- 🐛 [GitHub Issues](https://github.com/sickn33/antigravity-awesome-skills/issues)
- 💬 [GitHub Discussions](https://github.com/sickn33/antigravity-awesome-skills/discussions)

We are committed to making sure every contribution receives proper credit and respect. Thank you for contributing to the AI development community! 🙏

@@ -1,7 +1,7 @@
{
  "name": "antigravity-awesome-skills",
  "version": "9.1.0",
  "description": "1,332+ agentic skills for Claude Code, Gemini CLI, Cursor, Antigravity & more. Installer CLI.",
  "description": "1,340+ agentic skills for Claude Code, Gemini CLI, Cursor, Antigravity & more. Installer CLI.",
  "license": "MIT",
  "scripts": {
    "validate": "node tools/scripts/run-python.js tools/scripts/validate_skills.py",

@@ -1,199 +1,194 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/hf-cli"
name: hugging-face-cli
description: "The hf CLI provides direct terminal access to the Hugging Face Hub for downloading, uploading, and managing repositories, cache, and compute resources."
risk: safe
source: "https://github.com/huggingface/skills/tree/main/skills/hugging-face-cli"
date_added: "2026-02-27"
description: "Use the Hugging Face Hub CLI (`hf`) to download, upload, and manage models, datasets, and Spaces."
risk: unknown
---

# Hugging Face CLI

Install: `curl -LsSf https://hf.co/cli/install.sh | bash -s`.

The `hf` CLI provides direct terminal access to the Hugging Face Hub for downloading, uploading, and managing repositories, cache, and compute resources.

## When to Use

## When to Use This Skill

Use this skill when you need the `hf` CLI for Hub authentication, downloads, uploads, repo management, or basic compute operations.

Use this skill when:
- User needs to download models, datasets, or spaces
- Uploading files to Hub repositories
- Creating Hugging Face repositories
- Managing local cache
- Running compute jobs on HF infrastructure
- Working with Hugging Face Hub authentication

The Hugging Face Hub CLI tool `hf` is available. IMPORTANT: The `hf` command replaces the deprecated `huggingface-cli` command.

## Quick Command Reference

Use `hf --help` to view available functions. Note that auth commands are now all under `hf auth`, e.g. `hf auth whoami`.

| Task | Command |
|------|---------|
| Login | `hf auth login` |
| Download model | `hf download <repo_id>` |
| Download to folder | `hf download <repo_id> --local-dir ./path` |
| Upload folder | `hf upload <repo_id> . .` |
| Create repo | `hf repo create <name>` |
| Create tag | `hf repo tag create <repo_id> <tag>` |
| Delete files | `hf repo-files delete <repo_id> <files>` |
| List cache | `hf cache ls` |
| Remove from cache | `hf cache rm <repo_or_revision>` |
| List models | `hf models ls` |
| Get model info | `hf models info <model_id>` |
| List datasets | `hf datasets ls` |
| Get dataset info | `hf datasets info <dataset_id>` |
| List spaces | `hf spaces ls` |
| Get space info | `hf spaces info <space_id>` |
| List endpoints | `hf endpoints ls` |
| Run GPU job | `hf jobs run --flavor a10g-small <image> <cmd>` |
| Environment info | `hf env` |

Generated with `huggingface_hub v1.8.0`. Run `hf skills add --force` to regenerate.

## Core Commands
## Commands

### Authentication
```bash
hf auth login                     # Interactive login
hf auth login --token $HF_TOKEN   # Non-interactive
hf auth whoami                    # Check current user
hf auth list                      # List stored tokens
hf auth switch                    # Switch between tokens
hf auth logout                    # Log out
```

- `hf download REPO_ID` — Download files from the Hub. `[--type CHOICE --revision TEXT --include TEXT --exclude TEXT --cache-dir TEXT --local-dir TEXT --force-download --dry-run --quiet --max-workers INTEGER]`
- `hf env` — Print information about the environment.
- `hf sync` — Sync files between local directory and a bucket. `[--delete --ignore-times --ignore-sizes --plan TEXT --apply TEXT --dry-run --include TEXT --exclude TEXT --filter-from TEXT --existing --ignore-existing --verbose --quiet]`
- `hf upload REPO_ID` — Upload a file or a folder to the Hub. Recommended for single-commit uploads. `[--type CHOICE --revision TEXT --private --include TEXT --exclude TEXT --delete TEXT --commit-message TEXT --commit-description TEXT --create-pr --every FLOAT --quiet]`
- `hf upload-large-folder REPO_ID LOCAL_PATH` — Upload a large folder to the Hub. Recommended for resumable uploads. `[--type CHOICE --revision TEXT --private --include TEXT --exclude TEXT --num-workers INTEGER --no-report --no-bars]`
- `hf version` — Print information about the hf version.

### Download
```bash
hf download <repo_id>                            # Full repo to cache
hf download <repo_id> file.safetensors           # Specific file
hf download <repo_id> --local-dir ./models       # To local directory
hf download <repo_id> --include "*.safetensors"  # Filter by pattern
hf download <repo_id> --repo-type dataset        # Dataset
hf download <repo_id> --revision v1.0            # Specific version
```

### Upload
```bash
hf upload <repo_id> . .                          # Current dir to root
hf upload <repo_id> ./models /weights            # Folder to path
hf upload <repo_id> model.safetensors            # Single file
hf upload <repo_id> . . --repo-type dataset      # Dataset
hf upload <repo_id> . . --create-pr              # Create PR
hf upload <repo_id> . . --commit-message="msg"   # Custom message
```

### `hf auth` — Manage authentication (login, logout, etc.).

- `hf auth list` — List all stored access tokens.
- `hf auth login` — Login using a token from huggingface.co/settings/tokens. `[--add-to-git-credential --force]`
- `hf auth logout` — Logout from a specific token. `[--token-name TEXT]`
- `hf auth switch` — Switch between access tokens. `[--token-name TEXT --add-to-git-credential]`
- `hf auth whoami` — Find out which huggingface.co account you are logged in as. `[--format CHOICE]`

### Repository Management
```bash
hf repo create <name>                            # Create model repo
hf repo create <name> --repo-type dataset        # Create dataset
hf repo create <name> --private                  # Private repo
hf repo create <name> --repo-type space --space_sdk gradio  # Gradio space
hf repo delete <repo_id>                         # Delete repo
hf repo move <from_id> <to_id>                   # Move repo to new namespace
hf repo settings <repo_id> --private true        # Update repo settings
hf repo list --repo-type model                   # List repos
hf repo branch create <repo_id> release-v1       # Create branch
hf repo branch delete <repo_id> release-v1       # Delete branch
hf repo tag create <repo_id> v1.0                # Create tag
hf repo tag list <repo_id>                       # List tags
hf repo tag delete <repo_id> v1.0                # Delete tag
```

### Delete Files from Repo
```bash
hf repo-files delete <repo_id> folder/           # Delete folder
hf repo-files delete <repo_id> "*.txt"           # Delete with pattern
```

### `hf buckets` — Commands to interact with buckets.

- `hf buckets cp SRC` — Copy a single file to or from a bucket. `[--quiet]`
- `hf buckets create BUCKET_ID` — Create a new bucket. `[--private --exist-ok --quiet]`
- `hf buckets delete BUCKET_ID` — Delete a bucket. `[--yes --missing-ok --quiet]`
- `hf buckets info BUCKET_ID` — Get info about a bucket. `[--quiet]`
- `hf buckets list` — List buckets or files in a bucket. `[--human-readable --tree --recursive --format CHOICE --quiet]`
- `hf buckets move FROM_ID TO_ID` — Move (rename) a bucket to a new name or namespace.
- `hf buckets remove ARGUMENT` — Remove files from a bucket. `[--recursive --yes --dry-run --include TEXT --exclude TEXT --quiet]`
- `hf buckets sync` — Sync files between local directory and a bucket. `[--delete --ignore-times --ignore-sizes --plan TEXT --apply TEXT --dry-run --include TEXT --exclude TEXT --filter-from TEXT --existing --ignore-existing --verbose --quiet]`

### Cache Management
```bash
hf cache ls                      # List cached repos
hf cache ls --revisions          # Include individual revisions
hf cache rm model/gpt2           # Remove cached repo
hf cache rm <revision_hash>      # Remove cached revision
hf cache prune                   # Remove detached revisions
hf cache verify gpt2             # Verify checksums from cache
```

### Browse Hub
```bash
# Models
hf models ls                                       # List top trending models
hf models ls --search "MiniMax" --author MiniMaxAI # Search models
hf models ls --filter "text-generation" --limit 20 # Filter by task
hf models info MiniMaxAI/MiniMax-M2.1              # Get model info

# Datasets
hf datasets ls                                     # List top trending datasets
hf datasets ls --search "finepdfs" --sort downloads # Search datasets
hf datasets info HuggingFaceFW/finepdfs            # Get dataset info

# Spaces
hf spaces ls                                       # List top trending spaces
hf spaces ls --filter "3d" --limit 10              # Filter by 3D modeling spaces
hf spaces info enzostvs/deepsite                   # Get space info
```

### `hf cache` — Manage local cache directory.

- `hf cache list` — List cached repositories or revisions. `[--cache-dir TEXT --revisions --filter TEXT --format CHOICE --quiet --sort CHOICE --limit INTEGER]`
- `hf cache prune` — Remove detached revisions from the cache. `[--cache-dir TEXT --yes --dry-run]`
- `hf cache rm TARGETS` — Remove cached repositories or revisions. `[--cache-dir TEXT --yes --dry-run]`
- `hf cache verify REPO_ID` — Verify checksums for a single repo revision from cache or a local directory. `[--type CHOICE --revision TEXT --cache-dir TEXT --local-dir TEXT --fail-on-missing-files --fail-on-extra-files]`

### `hf collections` — Interact with collections on the Hub.

- `hf collections add-item COLLECTION_SLUG ITEM_ID ITEM_TYPE` — Add an item to a collection. `[--note TEXT --exists-ok]`
- `hf collections create TITLE` — Create a new collection on the Hub. `[--namespace TEXT --description TEXT --private --exists-ok]`
- `hf collections delete COLLECTION_SLUG` — Delete a collection from the Hub. `[--missing-ok]`
- `hf collections delete-item COLLECTION_SLUG ITEM_OBJECT_ID` — Delete an item from a collection. `[--missing-ok]`
- `hf collections info COLLECTION_SLUG` — Get info about a collection on the Hub. Output is in JSON format.
- `hf collections list` — List collections on the Hub. `[--owner TEXT --item TEXT --sort CHOICE --limit INTEGER --format CHOICE --quiet]`
- `hf collections update COLLECTION_SLUG` — Update a collection's metadata on the Hub. `[--title TEXT --description TEXT --position INTEGER --private --theme TEXT]`
- `hf collections update-item COLLECTION_SLUG ITEM_OBJECT_ID` — Update an item in a collection. `[--note TEXT --position INTEGER]`

### Jobs (Cloud Compute)
|
||||
```bash
|
||||
hf jobs run python:3.12 python script.py # Run on CPU
|
||||
hf jobs run --flavor a10g-small <image> <cmd> # Run on GPU
|
||||
hf jobs run --secrets HF_TOKEN <image> <cmd> # With HF token
|
||||
hf jobs ps # List jobs
|
||||
hf jobs logs <job_id> # View logs
|
||||
hf jobs cancel <job_id> # Cancel job
|
||||
```
|
||||
### `hf datasets` — Interact with datasets on the Hub.
|
||||
|
||||
### Inference Endpoints
|
||||
```bash
|
||||
hf endpoints ls # List endpoints
|
||||
hf endpoints deploy my-endpoint \
|
||||
--repo openai/gpt-oss-120b \
|
||||
--framework vllm \
|
||||
--accelerator gpu \
|
||||
--instance-size x4 \
|
||||
--instance-type nvidia-a10g \
|
||||
--region us-east-1 \
|
||||
--vendor aws
|
||||
hf endpoints describe my-endpoint # Show endpoint details
|
||||
hf endpoints pause my-endpoint # Pause endpoint
|
||||
hf endpoints resume my-endpoint # Resume endpoint
|
||||
hf endpoints scale-to-zero my-endpoint # Scale to zero
|
||||
hf endpoints delete my-endpoint --yes # Delete endpoint
|
||||
```
|
||||
**GPU Flavors:** `cpu-basic`, `cpu-upgrade`, `cpu-xl`, `t4-small`, `t4-medium`, `l4x1`, `l4x4`, `l40sx1`, `l40sx4`, `l40sx8`, `a10g-small`, `a10g-large`, `a10g-largex2`, `a10g-largex4`, `a100-large`, `h100`, `h100x8`
|
||||
- `hf datasets info DATASET_ID` — Get info about a dataset on the Hub. Output is in JSON format. `[--revision TEXT --expand TEXT]`
|
||||
- `hf datasets list` — List datasets on the Hub. `[--search TEXT --author TEXT --filter TEXT --sort CHOICE --limit INTEGER --expand TEXT --format CHOICE --quiet]`
|
||||
- `hf datasets parquet DATASET_ID` — List parquet file URLs available for a dataset. `[--subset TEXT --split TEXT --format CHOICE --quiet]`
|
||||
- `hf datasets sql SQL` — Execute a raw SQL query with DuckDB against dataset parquet URLs. `[--format CHOICE]`
|
||||
|
||||
## Common Patterns
|
||||
### `hf discussions` — Manage discussions and pull requests on the Hub.
|
||||
|
||||
### Download and Use Model Locally
|
||||
```bash
|
||||
# Download to local directory for deployment
|
||||
hf download meta-llama/Llama-3.2-1B-Instruct --local-dir ./model
|
||||
- `hf discussions close REPO_ID NUM` — Close a discussion or pull request. `[--comment TEXT --yes --type CHOICE]`
|
||||
- `hf discussions comment REPO_ID NUM` — Comment on a discussion or pull request. `[--body TEXT --body-file PATH --type CHOICE]`
|
||||
- `hf discussions create REPO_ID --title TEXT` — Create a new discussion or pull request on a repo. `[--body TEXT --body-file PATH --pull-request --type CHOICE]`
|
||||
- `hf discussions diff REPO_ID NUM` — Show the diff of a pull request. `[--type CHOICE]`
|
||||
- `hf discussions info REPO_ID NUM` — Get info about a discussion or pull request. `[--comments --diff --no-color --type CHOICE --format CHOICE]`
|
||||
- `hf discussions list REPO_ID` — List discussions and pull requests on a repo. `[--status CHOICE --kind CHOICE --author TEXT --limit INTEGER --type CHOICE --format CHOICE --quiet]`
|
||||
- `hf discussions merge REPO_ID NUM` — Merge a pull request. `[--comment TEXT --yes --type CHOICE]`
|
||||
- `hf discussions rename REPO_ID NUM NEW_TITLE` — Rename a discussion or pull request. `[--type CHOICE]`
|
||||
- `hf discussions reopen REPO_ID NUM` — Reopen a closed discussion or pull request. `[--comment TEXT --yes --type CHOICE]`
|
||||
|
||||
# Or use cache and get path
|
||||
MODEL_PATH=$(hf download meta-llama/Llama-3.2-1B-Instruct --quiet)
|
||||
```
|
||||
### Publish Model/Dataset

```bash
hf repo create my-username/my-model --private
hf upload my-username/my-model ./output . --commit-message="Initial release"
hf repo tag create my-username/my-model v1.0
```

### `hf endpoints` — Manage Hugging Face Inference Endpoints.

- `hf endpoints catalog deploy --repo TEXT` — Deploy an Inference Endpoint from the Model Catalog. `[--name TEXT --accelerator TEXT --namespace TEXT]`
- `hf endpoints catalog list` — List available Catalog models.
- `hf endpoints delete NAME` — Delete an Inference Endpoint permanently. `[--namespace TEXT --yes]`
- `hf endpoints deploy NAME --repo TEXT --framework TEXT --accelerator TEXT --instance-size TEXT --instance-type TEXT --region TEXT --vendor TEXT` — Deploy an Inference Endpoint from a Hub repository. `[--namespace TEXT --task TEXT --min-replica INTEGER --max-replica INTEGER --scale-to-zero-timeout INTEGER --scaling-metric CHOICE --scaling-threshold FLOAT]`
- `hf endpoints describe NAME` — Get information about an existing endpoint. `[--namespace TEXT]`
- `hf endpoints list` — List all Inference Endpoints for the given namespace. `[--namespace TEXT --format CHOICE --quiet]`
- `hf endpoints pause NAME` — Pause an Inference Endpoint. `[--namespace TEXT]`
- `hf endpoints resume NAME` — Resume an Inference Endpoint. `[--namespace TEXT --fail-if-already-running]`
- `hf endpoints scale-to-zero NAME` — Scale an Inference Endpoint to zero. `[--namespace TEXT]`
- `hf endpoints update NAME` — Update an existing endpoint. `[--namespace TEXT --repo TEXT --accelerator TEXT --instance-size TEXT --instance-type TEXT --framework TEXT --revision TEXT --task TEXT --min-replica INTEGER --max-replica INTEGER --scale-to-zero-timeout INTEGER --scaling-metric CHOICE --scaling-threshold FLOAT]`

### Sync Space with Local

```bash
hf upload my-username/my-space . . --repo-type space \
  --exclude="logs/*" --delete="*" --commit-message="Sync"
```

### Check Cache Usage

```bash
hf cache ls              # See all cached repos and sizes
hf cache rm model/gpt2   # Remove a repo from cache
```

### `hf extensions` — Manage hf CLI extensions.

- `hf extensions exec NAME` — Execute an installed extension.
- `hf extensions install REPO_ID` — Install an extension from a public GitHub repository. `[--force]`
- `hf extensions list` — List installed extension commands. `[--format CHOICE --quiet]`
- `hf extensions remove NAME` — Remove an installed extension.
- `hf extensions search` — Search extensions available on GitHub (tagged with 'hf-extension' topic). `[--format CHOICE --quiet]`

## Key Options

- `--repo-type`: `model` (default), `dataset`, `space`
- `--revision`: Branch, tag, or commit hash
- `--token`: Override authentication
- `--quiet`: Output only essential info (paths/URLs)

### `hf jobs` — Run and manage Jobs on the Hub.

- `hf jobs cancel JOB_ID` — Cancel a Job. `[--namespace TEXT]`
- `hf jobs hardware` — List available hardware options for Jobs.
- `hf jobs inspect JOB_IDS` — Display detailed information on one or more Jobs. `[--namespace TEXT]`
- `hf jobs logs JOB_ID` — Fetch the logs of a Job. `[--follow --tail INTEGER --namespace TEXT]`
- `hf jobs ps` — List Jobs. `[--all --namespace TEXT --filter TEXT --format TEXT --quiet]`
- `hf jobs run IMAGE COMMAND` — Run a Job. `[--env TEXT --secrets TEXT --label TEXT --volume TEXT --env-file TEXT --secrets-file TEXT --flavor CHOICE --timeout TEXT --detach --namespace TEXT]`
- `hf jobs scheduled delete SCHEDULED_JOB_ID` — Delete a scheduled Job. `[--namespace TEXT]`
- `hf jobs scheduled inspect SCHEDULED_JOB_IDS` — Display detailed information on one or more scheduled Jobs. `[--namespace TEXT]`
- `hf jobs scheduled ps` — List scheduled Jobs. `[--all --namespace TEXT --filter TEXT --format TEXT --quiet]`
- `hf jobs scheduled resume SCHEDULED_JOB_ID` — Resume (unpause) a scheduled Job. `[--namespace TEXT]`
- `hf jobs scheduled run SCHEDULE IMAGE COMMAND` — Schedule a Job. `[--suspend --concurrency --env TEXT --secrets TEXT --label TEXT --volume TEXT --env-file TEXT --secrets-file TEXT --flavor CHOICE --timeout TEXT --namespace TEXT]`
- `hf jobs scheduled suspend SCHEDULED_JOB_ID` — Suspend (pause) a scheduled Job. `[--namespace TEXT]`
- `hf jobs scheduled uv run SCHEDULE SCRIPT` — Run a UV script (local file or URL) on HF infrastructure. `[--suspend --concurrency --image TEXT --flavor CHOICE --env TEXT --secrets TEXT --label TEXT --volume TEXT --env-file TEXT --secrets-file TEXT --timeout TEXT --namespace TEXT --with TEXT --python TEXT]`
- `hf jobs stats` — Fetch the resource usage statistics and metrics of Jobs. `[--namespace TEXT]`
- `hf jobs uv run SCRIPT` — Run a UV script (local file or URL) on HF infrastructure. `[--image TEXT --flavor CHOICE --env TEXT --secrets TEXT --label TEXT --volume TEXT --env-file TEXT --secrets-file TEXT --timeout TEXT --detach --namespace TEXT --with TEXT --python TEXT]`

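The jobs flags compose in a predictable order: repeatable options first, then `IMAGE COMMAND` last. As a rough illustration, here is a small helper that assembles an `hf jobs run` argv list from those flag groups; the helper itself is hypothetical (not part of the CLI), and the flavor name is only an example.

```python
from typing import Dict, List, Optional


def build_jobs_run(image: str, command: List[str],
                   env: Optional[Dict[str, str]] = None,
                   flavor: Optional[str] = None,
                   detach: bool = False) -> List[str]:
    """Assemble an `hf jobs run` argv list (illustrative helper)."""
    cmd = ["hf", "jobs", "run"]
    for key, value in (env or {}).items():
        cmd += ["--env", f"{key}={value}"]   # repeatable --env KEY=VALUE pairs
    if flavor:
        cmd += ["--flavor", flavor]          # hardware flavor, see `hf jobs hardware`
    if detach:
        cmd.append("--detach")               # return immediately instead of streaming logs
    return cmd + [image] + command           # IMAGE COMMAND come last


argv = build_jobs_run("python:3.12", ["python", "-c", "print('hi')"],
                      env={"HF_TOKEN": "hf_xxx"}, flavor="cpu-basic", detach=True)
print(" ".join(argv))
```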
## References

- **Complete command reference**: See references/commands.md
- **Workflow examples**: See references/examples.md

### `hf models` — Interact with models on the Hub.

- `hf models info MODEL_ID` — Get info about a model on the Hub. Output is in JSON format. `[--revision TEXT --expand TEXT]`
- `hf models list` — List models on the Hub. `[--search TEXT --author TEXT --filter TEXT --num-parameters TEXT --sort CHOICE --limit INTEGER --expand TEXT --format CHOICE --quiet]`

### `hf papers` — Interact with papers on the Hub.

- `hf papers info PAPER_ID` — Get info about a paper on the Hub. Output is in JSON format.
- `hf papers list` — List daily papers on the Hub. `[--date TEXT --week TEXT --month TEXT --submitter TEXT --sort CHOICE --limit INTEGER --format CHOICE --quiet]`
- `hf papers read PAPER_ID` — Read a paper as markdown.
- `hf papers search QUERY` — Search papers on the Hub. `[--limit INTEGER --format CHOICE --quiet]`

### `hf repos` — Manage repos on the Hub.

- `hf repos branch create REPO_ID BRANCH` — Create a new branch for a repo on the Hub. `[--revision TEXT --type CHOICE --exist-ok]`
- `hf repos branch delete REPO_ID BRANCH` — Delete a branch from a repo on the Hub. `[--type CHOICE]`
- `hf repos create REPO_ID` — Create a new repo on the Hub. `[--type CHOICE --space-sdk TEXT --private --public --protected --exist-ok --resource-group-id TEXT --flavor TEXT --storage TEXT --sleep-time INTEGER --secrets TEXT --secrets-file TEXT --env TEXT --env-file TEXT]`
- `hf repos delete REPO_ID` — Delete a repo from the Hub. This is an irreversible operation. `[--type CHOICE --missing-ok]`
- `hf repos delete-files REPO_ID PATTERNS` — Delete files from a repo on the Hub. `[--type CHOICE --revision TEXT --commit-message TEXT --commit-description TEXT --create-pr]`
- `hf repos duplicate FROM_ID` — Duplicate a repo on the Hub (model, dataset, or Space). `[--type CHOICE --private --public --protected --exist-ok --flavor TEXT --storage TEXT --sleep-time INTEGER --secrets TEXT --secrets-file TEXT --env TEXT --env-file TEXT]`
- `hf repos move FROM_ID TO_ID` — Move a repository from a namespace to another namespace. `[--type CHOICE]`
- `hf repos settings REPO_ID` — Update the settings of a repository. `[--gated CHOICE --private --public --protected --type CHOICE]`
- `hf repos tag create REPO_ID TAG` — Create a tag for a repo. `[--message TEXT --revision TEXT --type CHOICE]`
- `hf repos tag delete REPO_ID TAG` — Delete a tag for a repo. `[--yes --type CHOICE]`
- `hf repos tag list REPO_ID` — List tags for a repo. `[--type CHOICE]`

### `hf skills` — Manage skills for AI assistants.

- `hf skills add` — Download a skill and install it for an AI assistant. `[--claude --codex --cursor --opencode --global --dest PATH --force]`
- `hf skills preview` — Print the generated SKILL.md to stdout.

### `hf spaces` — Interact with spaces on the Hub.

- `hf spaces dev-mode SPACE_ID` — Enable or disable dev mode on a Space. `[--stop]`
- `hf spaces hot-reload SPACE_ID` — Hot-reload any Python file of a Space without a full rebuild + restart. `[--local-file TEXT --skip-checks --skip-summary]`
- `hf spaces info SPACE_ID` — Get info about a space on the Hub. Output is in JSON format. `[--revision TEXT --expand TEXT]`
- `hf spaces list` — List spaces on the Hub. `[--search TEXT --author TEXT --filter TEXT --sort CHOICE --limit INTEGER --expand TEXT --format CHOICE --quiet]`

### `hf webhooks` — Manage webhooks on the Hub.

- `hf webhooks create --watch TEXT` — Create a new webhook. `[--url TEXT --job-id TEXT --domain CHOICE --secret TEXT]`
- `hf webhooks delete WEBHOOK_ID` — Delete a webhook permanently. `[--yes]`
- `hf webhooks disable WEBHOOK_ID` — Disable an active webhook.
- `hf webhooks enable WEBHOOK_ID` — Enable a disabled webhook.
- `hf webhooks info WEBHOOK_ID` — Show full details for a single webhook as JSON.
- `hf webhooks list` — List all webhooks for the current user. `[--format CHOICE --quiet]`
- `hf webhooks update WEBHOOK_ID` — Update an existing webhook. Only provided options are changed. `[--url TEXT --watch TEXT --domain CHOICE --secret TEXT]`

## Common options

- `--format` — Output format: `--format json` (or `--json`) or `--format table` (default).
- `-q / --quiet` — Minimal output.
- `--revision` — Git revision id which can be a branch name, a tag, or a commit hash.
- `--token` — Use a User Access Token. Prefer setting `HF_TOKEN` env var instead of passing `--token`.
- `--type` — The type of repository (model, dataset, or space).

## Mounting repos as local filesystems

To mount Hub repositories or buckets as local filesystems (no download, no copy, no waiting), use `hf-mount`. Files are fetched on demand. GitHub: https://github.com/huggingface/hf-mount

Install: `curl -fsSL https://raw.githubusercontent.com/huggingface/hf-mount/main/install.sh | sh`

Some command examples:

- `hf-mount start repo openai-community/gpt2 /tmp/gpt2` — mount a repo (read-only)
- `hf-mount start --hf-token $HF_TOKEN bucket myuser/my-bucket /tmp/data` — mount a bucket (read-write)
- `hf-mount status` / `hf-mount stop /tmp/data` — list or unmount

## Tips

- Use `hf <command> --help` for full options, descriptions, usage, and real-world examples
- Authenticate with `HF_TOKEN` env var (recommended) or with `--token`

@@ -0,0 +1,213 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-community-evals"
name: hugging-face-community-evals
description: Run local evaluations for Hugging Face Hub models with inspect-ai or lighteval.
risk: unknown
---

# Overview

## When to Use

Use this skill for local model evaluation, backend selection, and GPU smoke tests outside the Hugging Face Jobs workflow.

This skill is for **running evaluations against Hugging Face Hub models on local hardware**.

It covers:
- `inspect-ai` with local inference
- `lighteval` with local inference
- choosing between `vllm`, Hugging Face Transformers, and `accelerate`
- smoke tests, task selection, and backend fallback strategy

It does **not** cover:
- Hugging Face Jobs orchestration
- model-card or `model-index` edits
- README table extraction
- Artificial Analysis imports
- `.eval_results` generation or publishing
- PR creation or community-evals automation

If the user wants to **run the same eval remotely on Hugging Face Jobs**, hand off to the `hugging-face-jobs` skill and pass it one of the local scripts in this skill.

If the user wants to **publish results into the community evals workflow**, stop after generating the evaluation run and hand off that publishing step to `~/code/community-evals`.

> All paths below are relative to the directory containing this `SKILL.md`.

# When To Use Which Script

| Use case | Script |
|---|---|
| Local `inspect-ai` eval on a Hub model via inference providers | `scripts/inspect_eval_uv.py` |
| Local GPU eval with `inspect-ai` using `vllm` or Transformers | `scripts/inspect_vllm_uv.py` |
| Local GPU eval with `lighteval` using `vllm` or `accelerate` | `scripts/lighteval_vllm_uv.py` |
| Extra command patterns | `examples/USAGE_EXAMPLES.md` |

# Prerequisites

- Prefer `uv run` for local execution.
- Set `HF_TOKEN` for gated/private models.
- For local GPU runs, verify GPU access before starting:

```bash
uv --version
printenv HF_TOKEN >/dev/null
nvidia-smi
```

If `nvidia-smi` is unavailable, either:
- use `scripts/inspect_eval_uv.py` for lighter provider-backed evaluation, or
- hand off to the `hugging-face-jobs` skill if the user wants remote compute.

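The prerequisite checks above can be sketched in Python (stdlib only; the function name and remediation messages are illustrative, not part of the skill's scripts):

```python
import os
import shutil


def check_prereqs() -> list:
    """Return a list of human-readable problems; empty means good to go."""
    problems = []
    if shutil.which("uv") is None:
        problems.append("uv not found: install uv to use `uv run`")
    if not os.environ.get("HF_TOKEN"):
        problems.append("HF_TOKEN unset: required for gated/private models")
    if shutil.which("nvidia-smi") is None:
        # No local GPU tooling: fall back to the provider-backed script
        problems.append("nvidia-smi missing: use inspect_eval_uv.py or hand off to hugging-face-jobs")
    return problems


for problem in check_prereqs():
    print("WARN:", problem)
```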
# Core Workflow

1. Choose the evaluation framework.
   - Use `inspect-ai` when you want explicit task control and inspect-native flows.
   - Use `lighteval` when the benchmark is naturally expressed as a lighteval task string, especially leaderboard-style tasks.
2. Choose the inference backend.
   - Prefer `vllm` for throughput on supported architectures.
   - Use Hugging Face Transformers (`--backend hf`) or `accelerate` as compatibility fallbacks.
3. Start with a smoke test.
   - `inspect-ai`: add `--limit 10` or similar.
   - `lighteval`: add `--max-samples 10`.
4. Scale up only after the smoke test passes.
5. If the user wants remote execution, hand off to `hugging-face-jobs` with the same script + args.

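The decision steps above can be sketched as a small routing function. This is a minimal illustration under the assumption that a `|` in the task spec signals a lighteval task string; the function and field names are invented for the sketch:

```python
def plan_eval(task: str, vllm_supported: bool, smoke_test: bool = True) -> dict:
    """Pick framework, backend, and smoke-test flag for a local eval run."""
    # lighteval task strings look like "suite|task|num_fewshot"
    framework = "lighteval" if "|" in task else "inspect-ai"
    if framework == "lighteval":
        backend = "vllm" if vllm_supported else "accelerate"
        limit_flag = "--max-samples 10" if smoke_test else ""
    else:
        backend = "vllm" if vllm_supported else "hf"
        limit_flag = "--limit 10" if smoke_test else ""
    return {"framework": framework, "backend": backend, "limit_flag": limit_flag}


print(plan_eval("leaderboard|mmlu|5", vllm_supported=False))
```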
# Quick Start

## Option A: inspect-ai via Inference Providers

Best when the model is already supported by Hugging Face Inference Providers and you want the lowest local setup overhead.

```bash
uv run scripts/inspect_eval_uv.py \
  --model meta-llama/Llama-3.2-1B \
  --task mmlu \
  --limit 20
```

Use this path when:
- you want a quick local smoke test
- you do not need direct GPU control
- the task already exists in `inspect-evals`

## Option B: inspect-ai on Local GPU

Best when you need to load the Hub model directly, use `vllm`, or fall back to Transformers for unsupported architectures.

Local GPU:

```bash
uv run scripts/inspect_vllm_uv.py \
  --model meta-llama/Llama-3.2-1B \
  --task gsm8k \
  --limit 20
```

Transformers fallback:

```bash
uv run scripts/inspect_vllm_uv.py \
  --model microsoft/phi-2 \
  --task mmlu \
  --backend hf \
  --trust-remote-code \
  --limit 20
```

## Option C: lighteval on Local GPU

Best when the task is naturally expressed as a `lighteval` task string, especially Open LLM Leaderboard style benchmarks.

Local GPU:

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model meta-llama/Llama-3.2-3B-Instruct \
  --tasks "leaderboard|mmlu|5,leaderboard|gsm8k|5" \
  --max-samples 20 \
  --use-chat-template
```

`accelerate` fallback:

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model microsoft/phi-2 \
  --tasks "leaderboard|mmlu|5" \
  --backend accelerate \
  --trust-remote-code \
  --max-samples 20
```

# Remote Execution Boundary

This skill intentionally stops at **local execution and backend selection**.

If the user wants to:
- run these scripts on Hugging Face Jobs
- pick remote hardware
- pass secrets to remote jobs
- schedule recurring runs
- inspect / cancel / monitor jobs

then switch to the **`hugging-face-jobs`** skill and pass it one of these scripts plus the chosen arguments.

# Task Selection

`inspect-ai` examples:
- `mmlu`
- `gsm8k`
- `hellaswag`
- `arc_challenge`
- `truthfulqa`
- `winogrande`
- `humaneval`

`lighteval` task strings use `suite|task|num_fewshot`:
- `leaderboard|mmlu|5`
- `leaderboard|gsm8k|5`
- `leaderboard|arc_challenge|25`
- `lighteval|hellaswag|0`

Multiple `lighteval` tasks can be comma-separated in `--tasks`.

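A `--tasks` value can be split into its components like this, a minimal sketch mirroring the `suite|task|num_fewshot` convention above (the `LightevalTask` type and `parse_tasks` helper are invented for illustration, not part of lighteval):

```python
from typing import List, NamedTuple


class LightevalTask(NamedTuple):
    suite: str
    task: str
    num_fewshot: int


def parse_tasks(spec: str) -> List[LightevalTask]:
    """Parse a comma-separated `--tasks` value into structured entries."""
    entries = []
    for chunk in spec.split(","):
        # each chunk looks like "leaderboard|mmlu|5"
        suite, task, shots = chunk.strip().split("|")
        entries.append(LightevalTask(suite, task, int(shots)))
    return entries


print(parse_tasks("leaderboard|mmlu|5,leaderboard|gsm8k|5"))
```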
# Backend Selection

- Prefer `inspect_vllm_uv.py --backend vllm` for fast GPU inference on supported architectures.
- Use `inspect_vllm_uv.py --backend hf` when `vllm` does not support the model.
- Prefer `lighteval_vllm_uv.py --backend vllm` for throughput on supported models.
- Use `lighteval_vllm_uv.py --backend accelerate` as the compatibility fallback.
- Use `inspect_eval_uv.py` when Inference Providers already cover the model and you do not need direct GPU control.

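The fallback order reduces to a try-then-fall-back pattern. A tiny illustrative wrapper (the function and the fake runners are invented; in practice the failure would be vLLM raising on an unsupported architecture):

```python
from typing import Callable, List, Tuple


def run_with_fallback(backends: List[Tuple[str, Callable[[], str]]]) -> Tuple[str, str]:
    """Try each (name, runner) pair in order; return the first that succeeds."""
    errors = []
    for name, runner in backends:
        try:
            return name, runner()
        except RuntimeError as exc:  # e.g. vLLM rejects the architecture
            errors.append(f"{name}: {exc}")
    raise RuntimeError("all backends failed: " + "; ".join(errors))


def fake_vllm() -> str:
    raise RuntimeError("unsupported architecture")


def fake_hf() -> str:
    return "ok"


print(run_with_fallback([("vllm", fake_vllm), ("hf", fake_hf)]))
```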
# Hardware Guidance

| Model size | Suggested local hardware |
|---|---|
| `< 3B` | consumer GPU / Apple Silicon / small dev GPU |
| `3B - 13B` | stronger local GPU |
| `13B+` | high-memory local GPU or hand off to `hugging-face-jobs` |

For smoke tests, prefer cheaper local runs plus `--limit` or `--max-samples`.

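The size tiers above reduce to a simple routing rule. The thresholds are copied from the table; the function name is illustrative:

```python
def pick_hardware(num_params_b: float) -> str:
    """Map a model's parameter count (in billions) to a suggested venue."""
    if num_params_b < 3:
        return "consumer GPU / Apple Silicon / small dev GPU"
    if num_params_b <= 13:
        return "stronger local GPU"
    return "high-memory local GPU or hand off to hugging-face-jobs"


for size in (1.0, 8.0, 70.0):
    print(f"{size}B -> {pick_hardware(size)}")
```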
# Troubleshooting

- CUDA or vLLM OOM:
  - reduce `--batch-size`
  - reduce `--gpu-memory-utilization`
  - switch to a smaller model for the smoke test
  - if necessary, hand off to `hugging-face-jobs`
- Model unsupported by `vllm`:
  - switch to `--backend hf` for `inspect-ai`
  - switch to `--backend accelerate` for `lighteval`
- Gated/private repo access fails:
  - verify `HF_TOKEN`
- Custom model code required:
  - add `--trust-remote-code`

# Examples

See:
- `examples/USAGE_EXAMPLES.md` for local command patterns
- `scripts/inspect_eval_uv.py`
- `scripts/inspect_vllm_uv.py`
- `scripts/lighteval_vllm_uv.py`

@@ -0,0 +1,3 @@
# Hugging Face Token (required for gated/private models)
# Get your token at: https://huggingface.co/settings/tokens
HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxx
@@ -0,0 +1,101 @@
# Usage Examples

This document provides practical examples for **running evaluations locally** against Hugging Face Hub models.

## What this skill covers

- `inspect-ai` local runs
- `inspect-ai` with `vllm` or Transformers backends
- `lighteval` local runs with `vllm` or `accelerate`
- smoke tests and backend fallback patterns

## What this skill does NOT cover

- `model-index`
- `.eval_results`
- community eval publication workflows
- model-card PR creation
- Hugging Face Jobs orchestration

If you want to run these same scripts remotely, use the `hugging-face-jobs` skill and pass one of the scripts in `scripts/`.

## Setup

```bash
cd skills/hugging-face-evaluation
export HF_TOKEN=hf_xxx
uv --version
```

For local GPU runs:

```bash
nvidia-smi
```

## inspect-ai examples

### Quick smoke test

```bash
uv run scripts/inspect_eval_uv.py \
  --model meta-llama/Llama-3.2-1B \
  --task mmlu \
  --limit 10
```

### Local GPU with vLLM

```bash
uv run scripts/inspect_vllm_uv.py \
  --model meta-llama/Llama-3.1-8B-Instruct \
  --task gsm8k \
  --limit 20
```

### Transformers fallback

```bash
uv run scripts/inspect_vllm_uv.py \
  --model microsoft/phi-2 \
  --task mmlu \
  --backend hf \
  --trust-remote-code \
  --limit 20
```

## lighteval examples

### Single task

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model meta-llama/Llama-3.2-3B-Instruct \
  --tasks "leaderboard|mmlu|5" \
  --max-samples 20
```

### Multiple tasks

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model meta-llama/Llama-3.2-3B-Instruct \
  --tasks "leaderboard|mmlu|5,leaderboard|gsm8k|5" \
  --max-samples 20 \
  --use-chat-template
```

### accelerate fallback

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model microsoft/phi-2 \
  --tasks "leaderboard|mmlu|5" \
  --backend accelerate \
  --trust-remote-code \
  --max-samples 20
```

## Hand-off to Hugging Face Jobs

When local hardware is not enough, switch to the `hugging-face-jobs` skill and run one of these scripts remotely. Keep the script path and args; move the orchestration there.

@@ -0,0 +1,104 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "inspect-ai>=0.3.0",
#     "inspect-evals",
#     "openai",
# ]
# ///

"""
Entry point script for running inspect-ai evaluations against Hugging Face inference providers.
"""

from __future__ import annotations

import argparse
import os
import subprocess
import sys
from pathlib import Path
from typing import Optional


def _inspect_evals_tasks_root() -> Optional[Path]:
    """Return the installed inspect_evals package path if available."""
    try:
        import inspect_evals

        return Path(inspect_evals.__file__).parent
    except Exception:
        return None


def _normalize_task(task: str) -> str:
    """Allow lighteval-style `suite|task|shots` strings by keeping the task name."""
    if "|" in task:
        parts = task.split("|")
        if len(parts) >= 2 and parts[1]:
            return parts[1]
    return task


def main() -> None:
    parser = argparse.ArgumentParser(description="Inspect-ai job runner")
    parser.add_argument("--model", required=True, help="Model ID on Hugging Face Hub")
    parser.add_argument("--task", required=True, help="inspect-ai task to execute")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of samples to evaluate")
    parser.add_argument(
        "--tasks-root",
        default=None,
        help="Optional path to inspect task files. Defaults to the installed inspect_evals package.",
    )
    parser.add_argument(
        "--sandbox",
        default="local",
        help="Sandbox backend to use (default: local for HF jobs without Docker).",
    )
    args = parser.parse_args()

    # Ensure downstream libraries can read the token passed as a secret
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", hf_token)
        os.environ.setdefault("HF_HUB_TOKEN", hf_token)

    task = _normalize_task(args.task)
    tasks_root = Path(args.tasks_root) if args.tasks_root else _inspect_evals_tasks_root()
    if tasks_root and not tasks_root.exists():
        tasks_root = None

    cmd = [
        "inspect",
        "eval",
        task,
        "--model",
        f"hf-inference-providers/{args.model}",
        "--log-level",
        "info",
        # Reduce concurrency to avoid OOM errors (default is 32)
        "--max-connections",
        "1",
        # Set a small positive temperature (HF doesn't allow temperature=0)
        "--temperature",
        "0.001",
    ]

    if args.sandbox:
        cmd.extend(["--sandbox", args.sandbox])

    if args.limit:
        cmd.extend(["--limit", str(args.limit)])

    try:
        subprocess.run(cmd, check=True, cwd=tasks_root)
        print("Evaluation complete.")
    except subprocess.CalledProcessError as exc:
        location = f" (cwd={tasks_root})" if tasks_root else ""
        print(f"Evaluation failed with exit code {exc.returncode}{location}", file=sys.stderr)
        raise


if __name__ == "__main__":
    main()

@@ -0,0 +1,306 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "inspect-ai>=0.3.0",
#     "inspect-evals",
#     "vllm>=0.4.0",
#     "torch>=2.0.0",
#     "transformers>=4.40.0",
# ]
# ///

"""
Entry point script for running inspect-ai evaluations with vLLM or HuggingFace Transformers backend.

This script runs evaluations on custom HuggingFace models using local GPU inference,
separate from inference provider scripts (which use external APIs).

Usage (standalone):
    uv run scripts/inspect_vllm_uv.py --model "meta-llama/Llama-3.2-1B" --task "mmlu"

Model backends:
    - vllm: Fast inference with vLLM (recommended for large models)
    - hf: HuggingFace Transformers backend (broader model compatibility)
"""

from __future__ import annotations

import argparse
import os
import subprocess
import sys
from typing import Optional


def setup_environment() -> None:
    """Configure environment variables for HuggingFace authentication."""
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", hf_token)
        os.environ.setdefault("HF_HUB_TOKEN", hf_token)


def run_inspect_vllm(
    model_id: str,
    task: str,
    limit: Optional[int] = None,
    max_connections: int = 4,
    temperature: float = 0.0,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.8,
    dtype: str = "auto",
    trust_remote_code: bool = False,
    log_level: str = "info",
) -> None:
    """
    Run inspect-ai evaluation with vLLM backend.

    Args:
        model_id: HuggingFace model ID
        task: inspect-ai task to execute (e.g., "mmlu", "gsm8k")
        limit: Limit number of samples to evaluate
        max_connections: Maximum concurrent connections
        temperature: Sampling temperature
        tensor_parallel_size: Number of GPUs for tensor parallelism
        gpu_memory_utilization: GPU memory fraction
        dtype: Data type (auto, float16, bfloat16)
        trust_remote_code: Allow remote code execution
        log_level: Logging level
    """
    setup_environment()

    model_spec = f"vllm/{model_id}"
    cmd = [
        "inspect",
        "eval",
        task,
        "--model",
        model_spec,
        "--log-level",
        log_level,
        "--max-connections",
        str(max_connections),
    ]

    # vLLM supports temperature=0 unlike HF inference providers
    cmd.extend(["--temperature", str(temperature)])

    # `inspect eval` has no dedicated CLI flags for these backend settings;
    # forward them to vLLM as model args via `-M key=value`.
    if tensor_parallel_size != 1:
        cmd.extend(["-M", f"tensor_parallel_size={tensor_parallel_size}"])
    if gpu_memory_utilization != 0.8:
        cmd.extend(["-M", f"gpu_memory_utilization={gpu_memory_utilization}"])
    if dtype != "auto":
        cmd.extend(["-M", f"dtype={dtype}"])
    if trust_remote_code:
        cmd.extend(["-M", "trust_remote_code=true"])

    if limit:
        cmd.extend(["--limit", str(limit)])

    print(f"Running: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, check=True)
        print("Evaluation complete.")
    except subprocess.CalledProcessError as exc:
        print(f"Evaluation failed with exit code {exc.returncode}", file=sys.stderr)
        sys.exit(exc.returncode)


def run_inspect_hf(
    model_id: str,
    task: str,
    limit: Optional[int] = None,
    max_connections: int = 1,
    temperature: float = 0.001,
    device: str = "auto",
    dtype: str = "auto",
    trust_remote_code: bool = False,
    log_level: str = "info",
) -> None:
    """
    Run inspect-ai evaluation with HuggingFace Transformers backend.

    Use this when vLLM doesn't support the model architecture.

    Args:
        model_id: HuggingFace model ID
        task: inspect-ai task to execute
        limit: Limit number of samples
        max_connections: Maximum concurrent connections (keep low for memory)
        temperature: Sampling temperature
        device: Device to use (auto, cuda, cpu)
        dtype: Data type
        trust_remote_code: Allow remote code execution
        log_level: Logging level
    """
    setup_environment()

    model_spec = f"hf/{model_id}"

    cmd = [
        "inspect",
        "eval",
        task,
        "--model",
        model_spec,
        "--log-level",
        log_level,
        "--max-connections",
        str(max_connections),
        "--temperature",
        str(temperature),
    ]

    # As above, forward backend settings as model args rather than CLI flags.
    if device != "auto":
        cmd.extend(["-M", f"device={device}"])
    if dtype != "auto":
        cmd.extend(["-M", f"torch_dtype={dtype}"])
    if trust_remote_code:
        cmd.extend(["-M", "trust_remote_code=true"])

    if limit:
        cmd.extend(["--limit", str(limit)])

    print(f"Running: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, check=True)
        print("Evaluation complete.")
    except subprocess.CalledProcessError as exc:
        print(f"Evaluation failed with exit code {exc.returncode}", file=sys.stderr)
        sys.exit(exc.returncode)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Run inspect-ai evaluations with vLLM or HuggingFace Transformers on custom models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run MMLU with vLLM backend
  uv run scripts/inspect_vllm_uv.py --model meta-llama/Llama-3.2-1B --task mmlu

  # Run with HuggingFace Transformers backend
  uv run scripts/inspect_vllm_uv.py --model meta-llama/Llama-3.2-1B --task mmlu --backend hf

  # Run with limited samples for testing
  uv run scripts/inspect_vllm_uv.py --model meta-llama/Llama-3.2-1B --task mmlu --limit 10

  # Run on multiple GPUs with tensor parallelism
  uv run scripts/inspect_vllm_uv.py --model meta-llama/Llama-3.1-70B --task mmlu --tensor-parallel-size 4

Available tasks (from inspect-evals):
  - mmlu: Massive Multitask Language Understanding
  - gsm8k: Grade School Math
  - hellaswag: Common sense reasoning
  - arc_challenge: AI2 Reasoning Challenge
  - truthfulqa: TruthfulQA benchmark
  - winogrande: Winograd Schema Challenge
  - humaneval: Code generation (HumanEval)
""",
    )

    parser.add_argument(
        "--model",
        required=True,
        help="HuggingFace model ID (e.g., meta-llama/Llama-3.2-1B)",
    )
    parser.add_argument(
        "--task",
        required=True,
        help="inspect-ai task to execute (e.g., mmlu, gsm8k)",
    )
    parser.add_argument(
        "--backend",
        choices=["vllm", "hf"],
        default="vllm",
        help="Model backend (default: vllm)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit number of samples to evaluate",
    )
    parser.add_argument(
        "--max-connections",
        type=int,
        default=None,
        help="Maximum concurrent connections (default: 4 for vllm, 1 for hf)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Sampling temperature (default: 0.0 for vllm, 0.001 for hf)",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help="Number of GPUs for tensor parallelism (vLLM only, default: 1)",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.8,
        help="GPU memory fraction to use (vLLM only, default: 0.8)",
    )
    parser.add_argument(
        "--dtype",
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Data type for model weights (default: auto)",
    )
    parser.add_argument(
        "--device",
        default="auto",
        help="Device for HF backend (auto, cuda, cpu)",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Allow executing remote code from model repository",
    )
    parser.add_argument(
        "--log-level",
        default="info",
        choices=["debug", "info", "warning", "error"],
        help="Logging level (default: info)",
    )

    args = parser.parse_args()

    if args.backend == "vllm":
        run_inspect_vllm(
            model_id=args.model,
            task=args.task,
            limit=args.limit,
            max_connections=args.max_connections or 4,
            temperature=args.temperature if args.temperature is not None else 0.0,
            tensor_parallel_size=args.tensor_parallel_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            dtype=args.dtype,
            trust_remote_code=args.trust_remote_code,
            log_level=args.log_level,
        )
    else:
        run_inspect_hf(
            model_id=args.model,
            task=args.task,
            limit=args.limit,
            max_connections=args.max_connections or 1,
            temperature=args.temperature if args.temperature is not None else 0.001,
            device=args.device,
            dtype=args.dtype,
            trust_remote_code=args.trust_remote_code,
            log_level=args.log_level,
        )


if __name__ == "__main__":
    main()
@@ -0,0 +1,297 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "lighteval[accelerate,vllm]>=0.6.0",
#     "torch>=2.0.0",
#     "transformers>=4.40.0",
#     "accelerate>=0.30.0",
#     "vllm>=0.4.0",
# ]
# ///

"""
Entry point script for running lighteval evaluations with local GPU backends.

This script runs evaluations using vLLM or accelerate on custom HuggingFace models.
It is separate from the inference provider scripts and evaluates models directly on local hardware.

Usage (standalone):
    uv run scripts/lighteval_vllm_uv.py --model "meta-llama/Llama-3.2-1B" --tasks "leaderboard|mmlu|5"
"""

from __future__ import annotations

import argparse
import os
import subprocess
import sys
from typing import Optional


def setup_environment() -> None:
    """Configure environment variables for HuggingFace authentication."""
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", hf_token)
        os.environ.setdefault("HF_HUB_TOKEN", hf_token)


def run_lighteval_vllm(
    model_id: str,
    tasks: str,
    output_dir: Optional[str] = None,
    max_samples: Optional[int] = None,
    batch_size: int = 1,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.8,
    dtype: str = "auto",
    trust_remote_code: bool = False,
    use_chat_template: bool = False,
    system_prompt: Optional[str] = None,
) -> None:
    """
    Run lighteval with the vLLM backend for efficient GPU inference.

    Args:
        model_id: HuggingFace model ID (e.g., "meta-llama/Llama-3.2-1B")
        tasks: Task specification (e.g., "leaderboard|mmlu|5" or "lighteval|hellaswag|0")
        output_dir: Directory for evaluation results
        max_samples: Limit number of samples per task
        batch_size: Batch size for evaluation
        tensor_parallel_size: Number of GPUs for tensor parallelism
        gpu_memory_utilization: GPU memory fraction to use (0.0-1.0)
        dtype: Data type for model weights (auto, float16, bfloat16)
        trust_remote_code: Allow executing remote code from the model repo
        use_chat_template: Apply chat template for conversational models
        system_prompt: System prompt for chat models
    """
    setup_environment()

    # Build the `lighteval vllm` command
    cmd = [
        "lighteval",
        "vllm",
        model_id,
        tasks,
        "--batch-size", str(batch_size),
        "--tensor-parallel-size", str(tensor_parallel_size),
        "--gpu-memory-utilization", str(gpu_memory_utilization),
        "--dtype", dtype,
    ]

    if output_dir:
        cmd.extend(["--output-dir", output_dir])

    if max_samples:
        cmd.extend(["--max-samples", str(max_samples)])

    if trust_remote_code:
        cmd.append("--trust-remote-code")

    if use_chat_template:
        cmd.append("--use-chat-template")

    if system_prompt:
        cmd.extend(["--system-prompt", system_prompt])

    print(f"Running: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, check=True)
        print("Evaluation complete.")
    except subprocess.CalledProcessError as exc:
        print(f"Evaluation failed with exit code {exc.returncode}", file=sys.stderr)
        sys.exit(exc.returncode)


def run_lighteval_accelerate(
    model_id: str,
    tasks: str,
    output_dir: Optional[str] = None,
    max_samples: Optional[int] = None,
    batch_size: int = 1,
    dtype: str = "bfloat16",
    trust_remote_code: bool = False,
    use_chat_template: bool = False,
    system_prompt: Optional[str] = None,
) -> None:
    """
    Run lighteval with the accelerate backend for multi-GPU distributed inference.

    Use this backend when vLLM is not available or for models not supported by vLLM.

    Args:
        model_id: HuggingFace model ID
        tasks: Task specification
        output_dir: Directory for evaluation results
        max_samples: Limit number of samples per task
        batch_size: Batch size for evaluation
        dtype: Data type for model weights
        trust_remote_code: Allow executing remote code
        use_chat_template: Apply chat template
        system_prompt: System prompt for chat models
    """
    setup_environment()

    # Build the `lighteval accelerate` command
    cmd = [
        "lighteval",
        "accelerate",
        model_id,
        tasks,
        "--batch-size", str(batch_size),
        "--dtype", dtype,
    ]

    if output_dir:
        cmd.extend(["--output-dir", output_dir])

    if max_samples:
        cmd.extend(["--max-samples", str(max_samples)])

    if trust_remote_code:
        cmd.append("--trust-remote-code")

    if use_chat_template:
        cmd.append("--use-chat-template")

    if system_prompt:
        cmd.extend(["--system-prompt", system_prompt])

    print(f"Running: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, check=True)
        print("Evaluation complete.")
    except subprocess.CalledProcessError as exc:
        print(f"Evaluation failed with exit code {exc.returncode}", file=sys.stderr)
        sys.exit(exc.returncode)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Run lighteval evaluations with vLLM or accelerate backend on custom HuggingFace models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Run MMLU evaluation with vLLM
    uv run scripts/lighteval_vllm_uv.py --model meta-llama/Llama-3.2-1B --tasks "leaderboard|mmlu|5"

    # Run with the accelerate backend instead of vLLM
    uv run scripts/lighteval_vllm_uv.py --model meta-llama/Llama-3.2-1B --tasks "leaderboard|mmlu|5" --backend accelerate

    # Run with a chat template for instruction-tuned models
    uv run scripts/lighteval_vllm_uv.py --model meta-llama/Llama-3.2-1B-Instruct --tasks "leaderboard|mmlu|5" --use-chat-template

    # Run with limited samples for testing
    uv run scripts/lighteval_vllm_uv.py --model meta-llama/Llama-3.2-1B --tasks "leaderboard|mmlu|5" --max-samples 10

Task format:
    Tasks use the format "suite|task|num_fewshot":
    - leaderboard|mmlu|5 (MMLU with 5-shot)
    - lighteval|hellaswag|0 (HellaSwag zero-shot)
    - leaderboard|gsm8k|5 (GSM8K with 5-shot)
    - Multiple tasks: "leaderboard|mmlu|5,leaderboard|gsm8k|5"
""",
    )

    parser.add_argument(
        "--model",
        required=True,
        help="HuggingFace model ID (e.g., meta-llama/Llama-3.2-1B)",
    )
    parser.add_argument(
        "--tasks",
        required=True,
        help="Task specification (e.g., 'leaderboard|mmlu|5')",
    )
    parser.add_argument(
        "--backend",
        choices=["vllm", "accelerate"],
        default="vllm",
        help="Inference backend to use (default: vllm)",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory for evaluation results",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Limit number of samples per task (useful for testing)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for evaluation (default: 1)",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help="Number of GPUs for tensor parallelism (vLLM only, default: 1)",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.8,
        help="GPU memory fraction to use (vLLM only, default: 0.8)",
    )
    parser.add_argument(
        "--dtype",
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Data type for model weights (default: auto)",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Allow executing remote code from model repository",
    )
    parser.add_argument(
        "--use-chat-template",
        action="store_true",
        help="Apply chat template for instruction-tuned/chat models",
    )
    parser.add_argument(
        "--system-prompt",
        default=None,
        help="System prompt for chat models",
    )

    args = parser.parse_args()

    if args.backend == "vllm":
        run_lighteval_vllm(
            model_id=args.model,
            tasks=args.tasks,
            output_dir=args.output_dir,
            max_samples=args.max_samples,
            batch_size=args.batch_size,
            tensor_parallel_size=args.tensor_parallel_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            dtype=args.dtype,
            trust_remote_code=args.trust_remote_code,
            use_chat_template=args.use_chat_template,
            system_prompt=args.system_prompt,
        )
    else:
        run_lighteval_accelerate(
            model_id=args.model,
            tasks=args.tasks,
            output_dir=args.output_dir,
            max_samples=args.max_samples,
            batch_size=args.batch_size,
            # accelerate expects an explicit dtype; fall back to bfloat16
            dtype=args.dtype if args.dtype != "auto" else "bfloat16",
            trust_remote_code=args.trust_remote_code,
            use_chat_template=args.use_chat_template,
            system_prompt=args.system_prompt,
        )


if __name__ == "__main__":
    main()
@@ -1,127 +1,127 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-datasets"
name: hugging-face-dataset-viewer
description: Use this skill for Hugging Face Dataset Viewer API workflows that fetch subset/split metadata, paginate rows, search text, apply filters, download parquet URLs, and read size or statistics.
risk: unknown
---

# Hugging Face Dataset Viewer

Use this skill to execute read-only Dataset Viewer API calls for dataset exploration and extraction.

## Core workflow

1. Optionally validate dataset availability with `/is-valid`.
2. Resolve `config` + `split` with `/splits`.
3. Preview with `/first-rows`.
4. Paginate content with `/rows` using `offset` and `length` (max 100).
5. Use `/search` for text matching and `/filter` for row predicates.
6. Retrieve parquet links via `/parquet` and totals/metadata via `/size` and `/statistics`.

## Defaults

- Base URL: `https://datasets-server.huggingface.co`
- Default API method: `GET`
- Query params should be URL-encoded.
- `offset` is 0-based.
- `length` max is usually `100` for row-like endpoints.
- Gated/private datasets require `Authorization: Bearer <HF_TOKEN>`.

## Dataset Viewer

- `Validate dataset`: `/is-valid?dataset=<namespace/repo>`
- `List subsets and splits`: `/splits?dataset=<namespace/repo>`
- `Preview first rows`: `/first-rows?dataset=<namespace/repo>&config=<config>&split=<split>`
- `Paginate rows`: `/rows?dataset=<namespace/repo>&config=<config>&split=<split>&offset=<int>&length=<int>`
- `Search text`: `/search?dataset=<namespace/repo>&config=<config>&split=<split>&query=<text>&offset=<int>&length=<int>`
- `Filter with predicates`: `/filter?dataset=<namespace/repo>&config=<config>&split=<split>&where=<predicate>&orderby=<sort>&offset=<int>&length=<int>`
- `List parquet shards`: `/parquet?dataset=<namespace/repo>`
- `Get size totals`: `/size?dataset=<namespace/repo>`
- `Get column statistics`: `/statistics?dataset=<namespace/repo>&config=<config>&split=<split>`
- `Get Croissant metadata (if available)`: `/croissant?dataset=<namespace/repo>`

Pagination pattern:

```bash
curl "https://datasets-server.huggingface.co/rows?dataset=stanfordnlp/imdb&config=plain_text&split=train&offset=0&length=100"
curl "https://datasets-server.huggingface.co/rows?dataset=stanfordnlp/imdb&config=plain_text&split=train&offset=100&length=100"
```

When pagination is partial, use response fields such as `num_rows_total`, `num_rows_per_page`, and `partial` to drive continuation logic.

Search/filter notes:

- `/search` matches string columns (full-text style behavior is internal to the API).
- `/filter` requires predicate syntax in `where` and an optional sort in `orderby`.
- Keep filtering and searches read-only and side-effect free.
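
Because `where` predicates contain quotes and operators, filter URLs should be built with proper encoding (per the URL-encoding default above). A minimal sketch; the column name and predicate are hypothetical examples, not a documented schema:

```python
from urllib.parse import urlencode

BASE = "https://datasets-server.huggingface.co"

def filter_url(dataset: str, config: str, split: str, where: str,
               offset: int = 0, length: int = 100) -> str:
    """Build an encoded /filter URL; the predicate is URL-encoded by urlencode."""
    query = urlencode({
        "dataset": dataset,
        "config": config,
        "split": split,
        "where": where,
        "offset": offset,
        "length": length,
    })
    return f"{BASE}/filter?{query}"

# Hypothetical predicate on an integer column named `label`
url = filter_url("stanfordnlp/imdb", "plain_text", "train", '"label"=0')
print(url)
```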

## Querying Datasets

Use `npx parquetlens` with Hub parquet alias paths for SQL querying.

Parquet alias shape:

```text
hf://datasets/<namespace>/<repo>@~parquet/<config>/<split>/<shard>.parquet
```

Derive `<config>`, `<split>`, and `<shard>` from the Dataset Viewer `/parquet` endpoint:

```bash
curl -s "https://datasets-server.huggingface.co/parquet?dataset=cfahlgren1/hub-stats" \
  | jq -r '.parquet_files[] | "hf://datasets/\(.dataset)@~parquet/\(.config)/\(.split)/\(.filename)"'
```
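
When `jq` is unavailable, the same aliases can be derived in Python from the `/parquet` JSON. A sketch equivalent to the `jq` expression, shown on a synthetic response rather than a live call:

```python
def parquet_aliases(parquet_response: dict) -> list[str]:
    """Map /parquet entries to hf:// @~parquet alias paths."""
    return [
        f"hf://datasets/{f['dataset']}@~parquet/{f['config']}/{f['split']}/{f['filename']}"
        for f in parquet_response.get("parquet_files", [])
    ]

# Synthetic /parquet response (real responses list one entry per shard)
sample = {"parquet_files": [
    {"dataset": "cfahlgren1/hub-stats", "config": "default",
     "split": "train", "filename": "0000.parquet"},
]}
print(parquet_aliases(sample)[0])
# → hf://datasets/cfahlgren1/hub-stats@~parquet/default/train/0000.parquet
```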

Run a SQL query:

```bash
npx -y -p parquetlens -p @parquetlens/sql parquetlens \
  "hf://datasets/<namespace>/<repo>@~parquet/<config>/<split>/<shard>.parquet" \
  --sql "SELECT * FROM data LIMIT 20"
```

### SQL export

- CSV: `--sql "COPY (SELECT * FROM data LIMIT 1000) TO 'export.csv' (FORMAT CSV, HEADER, DELIMITER ',')"`
- JSON: `--sql "COPY (SELECT * FROM data LIMIT 1000) TO 'export.json' (FORMAT JSON)"`
- Parquet: `--sql "COPY (SELECT * FROM data LIMIT 1000) TO 'export.parquet' (FORMAT PARQUET)"`

## Creating and Uploading Datasets

Use one of these flows depending on dependency constraints.

Zero local dependencies (Hub UI):

- Create the dataset repo in the browser: `https://huggingface.co/new-dataset`
- Upload parquet files on the repo's "Files and versions" page.
- Verify shards appear in the Dataset Viewer:

```bash
curl -s "https://datasets-server.huggingface.co/parquet?dataset=<namespace>/<repo>"
```

Low-dependency CLI flow (`npx @huggingface/hub` / `hfjs`):

- Set the auth token:

```bash
export HF_TOKEN=<your_hf_token>
```

- Upload a parquet folder to a dataset repo (auto-creates the repo if missing):

```bash
npx -y @huggingface/hub upload datasets/<namespace>/<repo> ./local/parquet-folder data
```

- Upload as a private repo on creation:

```bash
npx -y @huggingface/hub upload datasets/<namespace>/<repo> ./local/parquet-folder data --private
```

After upload, call `/parquet` to discover `<config>/<split>/<shard>` values for querying with `@~parquet`.

## When to Use

Use this skill when you need read-only exploration of a Hugging Face dataset through the Dataset Viewer API.
@@ -0,0 +1,304 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-gradio"
name: hugging-face-gradio
description: Build or edit Gradio apps, layouts, components, and chat interfaces in Python.
risk: unknown
---

# Gradio

## When to Use

Use this skill when a user wants a Gradio demo, UI prototype, or Python-based ML interface.

Gradio is a Python library for building interactive web UIs and ML demos. This skill covers the core API, patterns, and examples.

## Guides

Detailed guides on specific topics (read these when relevant):

- [Quickstart](https://www.gradio.app/guides/quickstart)
- [The Interface Class](https://www.gradio.app/guides/the-interface-class)
- [Blocks and Event Listeners](https://www.gradio.app/guides/blocks-and-event-listeners)
- [Controlling Layout](https://www.gradio.app/guides/controlling-layout)
- [More Blocks Features](https://www.gradio.app/guides/more-blocks-features)
- [Custom CSS and JS](https://www.gradio.app/guides/custom-CSS-and-JS)
- [Streaming Outputs](https://www.gradio.app/guides/streaming-outputs)
- [Streaming Inputs](https://www.gradio.app/guides/streaming-inputs)
- [Sharing Your App](https://www.gradio.app/guides/sharing-your-app)
- [Custom HTML Components](https://www.gradio.app/guides/custom-HTML-components)
- [Getting Started with the Python Client](https://www.gradio.app/guides/getting-started-with-the-python-client)
- [Getting Started with the JS Client](https://www.gradio.app/guides/getting-started-with-the-js-client)

## Core Patterns

**Interface** (high-level): wraps a function with input/output components.

```python
import gradio as gr

def greet(name):
    return f"Hello {name}!"

gr.Interface(fn=greet, inputs="text", outputs="text").launch()
```

**Blocks** (low-level): flexible layout with explicit event wiring.

```python
import gradio as gr

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    output = gr.Textbox(label="Greeting")
    btn = gr.Button("Greet")
    btn.click(fn=lambda n: f"Hello {n}!", inputs=name, outputs=output)

demo.launch()
```

**ChatInterface**: high-level wrapper for chatbot UIs.

```python
import gradio as gr

def respond(message, history):
    return f"You said: {message}"

gr.ChatInterface(fn=respond).launch()
```
|
||||
## Key Component Signatures
|
||||
|
||||
### `Textbox(value: str | I18nData | Callable | None = None, type: Literal['text', 'password', 'email'] = "text", lines: int = 1, max_lines: int | None = None, placeholder: str | I18nData | None = None, label: str | I18nData | None = None, info: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, autofocus: bool = False, autoscroll: bool = True, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", text_align: Literal['left', 'right'] | None = None, rtl: bool = False, buttons: list[Literal['copy'] | Button] | None = None, max_length: int | None = None, submit_btn: str | bool | None = False, stop_btn: str | bool | None = False, html_attributes: InputHTMLAttributes | None = None)`
|
||||
Creates a textarea for user to enter string input or display string output..
|
||||
|
||||
### `Number(value: float | Callable | None = None, label: str | I18nData | None = None, placeholder: str | I18nData | None = None, info: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", buttons: list[Button] | None = None, precision: int | None = None, minimum: float | None = None, maximum: float | None = None, step: float = 1)`
|
||||
Creates a numeric field for user to enter numbers as input or display numeric output..
|
||||
|
||||
### `Slider(minimum: float = 0, maximum: float = 100, value: float | Callable | None = None, step: float | None = None, precision: int | None = None, label: str | I18nData | None = None, info: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", randomize: bool = False, buttons: list[Literal['reset']] | None = None)`
|
||||
Creates a slider that ranges from {minimum} to {maximum} with a step size of {step}..
|
||||
|
||||
### `Checkbox(value: bool | Callable = False, label: str | I18nData | None = None, info: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", buttons: list[Button] | None = None)`
|
||||
Creates a checkbox that can be set to `True` or `False`.
|
||||
|
||||
### `Dropdown(choices: Sequence[str | int | float | tuple[str, str | int | float]] | None = None, value: str | int | float | Sequence[str | int | float] | Callable | DefaultValue | None = DefaultValue(), type: Literal['value', 'index'] = "value", multiselect: bool | None = None, allow_custom_value: bool = False, max_choices: int | None = None, filterable: bool = True, label: str | I18nData | None = None, info: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", buttons: list[Button] | None = None)`
|
||||
Creates a dropdown of choices from which a single entry or multiple entries can be selected (as an input component) or displayed (as an output component)..
|
||||
|
||||
### `Radio(choices: Sequence[str | int | float | tuple[str, str | int | float]] | None = None, value: str | int | float | Callable | None = None, type: Literal['value', 'index'] = "value", label: str | I18nData | None = None, info: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", rtl: bool = False, buttons: list[Button] | None = None)`
|
||||
Creates a set of (string or numeric type) radio buttons of which only one can be selected..
|
||||
|
||||
### `Image(value: str | PIL.Image.Image | np.ndarray | Callable | None = None, format: str = "webp", height: int | str | None = None, width: int | str | None = None, image_mode: Literal['1', 'L', 'P', 'RGB', 'RGBA', 'CMYK', 'YCbCr', 'LAB', 'HSV', 'I', 'F'] | None = "RGB", sources: list[Literal['upload', 'webcam', 'clipboard']] | Literal['upload', 'webcam', 'clipboard'] | None = None, type: Literal['numpy', 'pil', 'filepath'] = "numpy", label: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, buttons: list[Literal['download', 'share', 'fullscreen'] | Button] | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, streaming: bool = False, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", webcam_options: WebcamOptions | None = None, placeholder: str | None = None, watermark: WatermarkOptions | None = None)`
|
||||
Creates an image component that can be used to upload images (as an input) or display images (as an output)..
|
||||
|
||||
### `Audio(value: str | Path | tuple[int, np.ndarray] | Callable | None = None, sources: list[Literal['upload', 'microphone']] | Literal['upload', 'microphone'] | None = None, type: Literal['numpy', 'filepath'] = "numpy", label: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, streaming: bool = False, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", format: Literal['wav', 'mp3'] | None = None, autoplay: bool = False, editable: bool = True, buttons: list[Literal['download', 'share'] | Button] | None = None, waveform_options: WaveformOptions | dict | None = None, loop: bool = False, recording: bool = False, subtitles: str | Path | list[dict[str, Any]] | None = None, playback_position: float = 0)`
Creates an audio component that can be used to upload/record audio (as an input) or display audio (as an output).
### `Video(value: str | Path | Callable | None = None, format: str | None = None, sources: list[Literal['upload', 'webcam']] | Literal['upload', 'webcam'] | None = None, height: int | str | None = None, width: int | str | None = None, label: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", webcam_options: WebcamOptions | None = None, include_audio: bool | None = None, autoplay: bool = False, buttons: list[Literal['download', 'share'] | Button] | None = None, loop: bool = False, streaming: bool = False, watermark: WatermarkOptions | None = None, subtitles: str | Path | list[dict[str, Any]] | None = None, playback_position: float = 0)`
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
### `File(value: str | list[str] | Callable | None = None, file_count: Literal['single', 'multiple', 'directory'] = "single", file_types: list[str] | None = None, type: Literal['filepath', 'binary'] = "filepath", label: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, height: int | str | float | None = None, interactive: bool | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", allow_reordering: bool = False, buttons: list[Button] | None = None)`
Creates a file component that allows uploading one or more generic files (when used as an input) or displaying generic files or URLs for download (as output).
### `Chatbot(value: list[MessageDict | Message] | Callable | None = None, label: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, autoscroll: bool = True, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", height: int | str | None = 400, resizable: bool = False, max_height: int | str | None = None, min_height: int | str | None = None, editable: Literal['user', 'all'] | None = None, latex_delimiters: list[dict[str, str | bool]] | None = None, rtl: bool = False, buttons: list[Literal['share', 'copy', 'copy_all'] | Button] | None = None, watermark: str | None = None, avatar_images: tuple[str | Path | None, str | Path | None] | None = None, sanitize_html: bool = True, render_markdown: bool = True, feedback_options: list[str] | tuple[str, ...] | None = ('Like', 'Dislike'), feedback_value: Sequence[str | None] | None = None, line_breaks: bool = True, layout: Literal['panel', 'bubble'] | None = None, placeholder: str | None = None, examples: list[ExampleMessage] | None = None, allow_file_downloads: bool = True, group_consecutive_messages: bool = True, allow_tags: list[str] | bool = True, reasoning_tags: list[tuple[str, str]] | None = None, like_user_message: bool = False)`
Creates a chatbot that displays user-submitted messages and responses.
### `Button(value: str | I18nData | Callable = "Run", every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, variant: Literal['primary', 'secondary', 'stop', 'huggingface'] = "secondary", size: Literal['sm', 'md', 'lg'] = "lg", icon: str | Path | None = None, link: str | None = None, link_target: Literal['_self', '_blank', '_parent', '_top'] = "_self", visible: bool | Literal['hidden'] = True, interactive: bool = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", scale: int | None = None, min_width: int | None = None)`
Creates a button that can be assigned arbitrary .click() events.
### `Markdown(value: str | I18nData | Callable | None = None, label: str | I18nData | None = None, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool | None = None, rtl: bool = False, latex_delimiters: list[dict[str, str | bool]] | None = None, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", sanitize_html: bool = True, line_breaks: bool = False, header_links: bool = False, height: int | str | None = None, max_height: int | str | None = None, min_height: int | str | None = None, buttons: list[Literal['copy']] | None = None, container: bool = False, padding: bool = False)`
Used to render arbitrary Markdown output.
### `HTML(value: Any | Callable | None = None, label: str | I18nData | None = None, html_template: str = "${value}", css_template: str = "", js_on_load: str | None = "element.addEventListener('click', function() { trigger('click') });", apply_default_css: bool = True, every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, show_label: bool = False, visible: bool | Literal['hidden'] = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | tuple[int | str, ...] | None = None, preserved_by_key: list[str] | str | None = "value", min_height: int | None = None, max_height: int | None = None, container: bool = False, padding: bool = False, autoscroll: bool = False, buttons: list[Button] | None = None, server_functions: list[Callable] | None = None, props: Any)`
Creates a component with arbitrary HTML.
## Custom HTML Components

If a task requires significant customization of an existing component or a component that doesn't exist in Gradio, you can create one with `gr.HTML`. It supports `html_template` (with `${}` JS expressions and `{{}}` Handlebars syntax), `css_template` for scoped styles, and `js_on_load` for interactivity — where `props.value` updates the component value and `trigger('event_name')` fires Gradio events. For reuse, subclass `gr.HTML` and define `api_info()` for API/MCP support. See the [full guide](https://www.gradio.app/guides/custom-HTML-components).

Here's an example that shows how to create and use these kinds of components:

```python
import gradio as gr

class StarRating(gr.HTML):
    def __init__(self, label, value=0, **kwargs):
        html_template = """
        <h2>${label} rating:</h2>
        ${Array.from({length: 5}, (_, i) => `<img class='${i < value ? '' : 'faded'}' src='https://upload.wikimedia.org/wikipedia/commons/d/df/Award-star-gold-3d.svg'>`).join('')}
        """
        css_template = """
        img { height: 50px; display: inline-block; cursor: pointer; }
        .faded { filter: grayscale(100%); opacity: 0.3; }
        """
        js_on_load = """
        const imgs = element.querySelectorAll('img');
        imgs.forEach((img, index) => {
            img.addEventListener('click', () => {
                props.value = index + 1;
            });
        });
        """
        super().__init__(value=value, label=label, html_template=html_template, css_template=css_template, js_on_load=js_on_load, **kwargs)

    def api_info(self):
        return {"type": "integer", "minimum": 0, "maximum": 5}

with gr.Blocks() as demo:
    gr.Markdown("# Restaurant Review")
    food_rating = StarRating(label="Food", value=3)
    service_rating = StarRating(label="Service", value=3)
    ambience_rating = StarRating(label="Ambience", value=3)
    average_btn = gr.Button("Calculate Average Rating")
    rating_output = StarRating(label="Average", value=3)

    def calculate_average(food, service, ambience):
        return round((food + service + ambience) / 3)

    average_btn.click(
        fn=calculate_average,
        inputs=[food_rating, service_rating, ambience_rating],
        outputs=rating_output
    )

demo.launch()
```

## Event Listeners

All event listeners share the same signature:

```python
component.event_name(
    fn: Callable | None | Literal["decorator"] = "decorator",
    inputs: Component | Sequence[Component] | set[Component] | None = None,
    outputs: Component | Sequence[Component] | set[Component] | None = None,
    api_name: str | None = None,
    api_description: str | None | Literal[False] = None,
    scroll_to_output: bool = False,
    show_progress: Literal["full", "minimal", "hidden"] = "full",
    show_progress_on: Component | Sequence[Component] | None = None,
    queue: bool = True,
    batch: bool = False,
    max_batch_size: int = 4,
    preprocess: bool = True,
    postprocess: bool = True,
    cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
    trigger_mode: Literal["once", "multiple", "always_last"] | None = None,
    js: str | Literal[True] | None = None,
    concurrency_limit: int | None | Literal["default"] = "default",
    concurrency_id: str | None = None,
    api_visibility: Literal["public", "private", "undocumented"] = "public",
    time_limit: int | None = None,
    stream_every: float = 0.5,
    key: int | str | tuple[int | str, ...] | None = None,
    validator: Callable | None = None,
) -> Dependency
```

Supported events per component:

- **AnnotatedImage**: select
- **Audio**: stream, change, clear, play, pause, stop, start_recording, pause_recording, stop_recording, upload, input
- **BarPlot**: select, double_click
- **BrowserState**: change
- **Button**: click
- **Chatbot**: change, select, like, retry, undo, example_select, option_select, clear, copy, edit
- **Checkbox**: change, input, select
- **CheckboxGroup**: change, input, select
- **ClearButton**: click
- **Code**: change, input, focus, blur
- **ColorPicker**: change, input, submit, focus, blur
- **Dataframe**: change, input, select, edit
- **Dataset**: click, select
- **DateTime**: change, submit
- **DeepLinkButton**: click
- **Dialogue**: change, input, submit
- **DownloadButton**: click
- **Dropdown**: change, input, select, focus, blur, key_up
- **DuplicateButton**: click
- **File**: change, select, clear, upload, delete, download
- **FileExplorer**: change, input, select
- **Gallery**: select, upload, change, delete, preview_close, preview_open
- **HTML**: change, input, click, double_click, submit, stop, edit, clear, play, pause, end, start_recording, pause_recording, stop_recording, focus, blur, upload, release, select, stream, like, example_select, option_select, load, key_up, apply, delete, tick, undo, retry, expand, collapse, download, copy
- **HighlightedText**: change, select
- **Image**: clear, change, stream, select, upload, input
- **ImageEditor**: clear, change, input, select, upload, apply
- **ImageSlider**: clear, change, stream, select, upload, input
- **JSON**: change
- **Label**: change, select
- **LinePlot**: select, double_click
- **LoginButton**: click
- **Markdown**: change, copy
- **Model3D**: change, upload, edit, clear
- **MultimodalTextbox**: change, input, select, submit, focus, blur, stop
- **Navbar**: change
- **Number**: change, input, submit, focus, blur
- **ParamViewer**: change, upload
- **Plot**: change
- **Radio**: select, change, input
- **ScatterPlot**: select, double_click
- **SimpleImage**: clear, change, upload
- **Slider**: change, input, release
- **State**: change
- **Textbox**: change, input, select, submit, focus, blur, stop, copy
- **Timer**: tick
- **UploadButton**: click, upload
- **Video**: change, clear, start_recording, stop_recording, stop, play, pause, end, upload, input

## Prediction CLI

The `gradio` CLI includes `info` and `predict` commands for interacting with Gradio apps programmatically. These are especially useful for coding agents that need to use Spaces in their workflows.

### `gradio info` — Discover endpoints and parameters

```bash
gradio info <space_id_or_url>
```

Returns a JSON payload describing all endpoints, their parameters (with types and defaults), and return values.

```bash
gradio info gradio/calculator
# {
#   "/predict": {
#     "parameters": [
#       {"name": "num1", "required": true, "default": null, "type": {"type": "number"}},
#       {"name": "operation", "required": true, "default": null, "type": {"enum": ["add", "subtract", "multiply", "divide"], "type": "string"}},
#       {"name": "num2", "required": true, "default": null, "type": {"type": "number"}}
#     ],
#     "returns": [{"name": "output", "type": {"type": "number"}}],
#     "description": ""
#   }
# }
```

File-type parameters show `"type": "filepath"` with instructions to include `"meta": {"_type": "gradio.FileData"}` — this signals the file will be uploaded to the remote server.

### `gradio predict` — Send predictions

```bash
gradio predict <space_id_or_url> <endpoint> <json_payload>
```

Returns a JSON object with named output keys.

```bash
# Simple numeric prediction
gradio predict gradio/calculator /predict '{"num1": 5, "operation": "multiply", "num2": 3}'
# {"output": 15}

# Image generation
gradio predict black-forest-labs/FLUX.2-dev /infer '{"prompt": "A majestic dragon"}'
# {"Result": "/tmp/gradio/.../image.webp", "Seed": 1117868604}

# File upload (must include meta key)
gradio predict gradio/image_mod /predict '{"image": {"path": "/path/to/image.png", "meta": {"_type": "gradio.FileData"}}}'
# {"output": "/tmp/gradio/.../output.png"}
```

Both commands accept `--token` for accessing private Spaces.

## Additional Reference

- [End-to-End Examples](examples.md) — complete working apps

@@ -0,0 +1,613 @@

# Gradio End-to-End Examples

Complete working Gradio apps for reference.

## Blocks Essay Simple

```python
import gradio as gr

def change_textbox(choice):
    if choice == "short":
        return gr.Textbox(lines=2, visible=True)
    elif choice == "long":
        return gr.Textbox(lines=8, visible=True, value="Lorem ipsum dolor sit amet")
    else:
        return gr.Textbox(visible=False)

with gr.Blocks() as demo:
    radio = gr.Radio(
        ["short", "long", "none"], label="What kind of essay would you like to write?"
    )
    text = gr.Textbox(lines=2, interactive=True, buttons=["copy"])
    radio.change(fn=change_textbox, inputs=radio, outputs=text)

demo.launch()
```

## Blocks Flipper

```python
import numpy as np
import gradio as gr

def flip_text(x):
    return x[::-1]

def flip_image(x):
    return np.fliplr(x)

with gr.Blocks() as demo:
    gr.Markdown("Flip text or image files using this demo.")
    with gr.Tab("Flip Text"):
        text_input = gr.Textbox()
        text_output = gr.Textbox()
        text_button = gr.Button("Flip")
    with gr.Tab("Flip Image"):
        with gr.Row():
            image_input = gr.Image()
            image_output = gr.Image()
        image_button = gr.Button("Flip")

    with gr.Accordion("Open for More!", open=False):
        gr.Markdown("Look at me...")
        temp_slider = gr.Slider(
            0, 1,
            value=0.1,
            step=0.1,
            interactive=True,
            label="Slide me",
        )

    text_button.click(flip_text, inputs=text_input, outputs=text_output)
    image_button.click(flip_image, inputs=image_input, outputs=image_output)

demo.launch()
```

## Blocks Form

```python
import gradio as gr

with gr.Blocks() as demo:
    name_box = gr.Textbox(label="Name")
    age_box = gr.Number(label="Age", minimum=0, maximum=100)
    symptoms_box = gr.CheckboxGroup(["Cough", "Fever", "Runny Nose"])
    submit_btn = gr.Button("Submit")

    with gr.Column(visible=False) as output_col:
        diagnosis_box = gr.Textbox(label="Diagnosis")
        patient_summary_box = gr.Textbox(label="Patient Summary")

    def submit(name, age, symptoms):
        return {
            submit_btn: gr.Button(visible=False),
            output_col: gr.Column(visible=True),
            diagnosis_box: "covid" if "Cough" in symptoms else "flu",
            patient_summary_box: f"{name}, {age} y/o",
        }

    submit_btn.click(
        submit,
        [name_box, age_box, symptoms_box],
        [submit_btn, diagnosis_box, patient_summary_box, output_col],
    )

demo.launch()
```

## Blocks Hello

```python
import gradio as gr

def welcome(name):
    return f"Welcome to Gradio, {name}!"

with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Hello World!
    Start typing below to see the output.
    """)
    inp = gr.Textbox(placeholder="What is your name?")
    out = gr.Textbox()
    inp.change(welcome, inp, out)

demo.launch()
```

## Blocks Layout

```python
import gradio as gr

demo = gr.Blocks()

with demo:
    with gr.Row():
        gr.Image(interactive=True, scale=2)
        gr.Image()
    with gr.Row():
        gr.Textbox(label="Text")
        gr.Number(label="Count", scale=2)
        gr.Radio(choices=["One", "Two"])
    with gr.Row():
        gr.Button("500", scale=0, min_width=500)
        gr.Button("A", scale=0)
        gr.Button("grow")
    with gr.Row():
        gr.Textbox()
        gr.Textbox()
        gr.Button()
    with gr.Row():
        with gr.Row():
            with gr.Column():
                gr.Textbox(label="Text")
                gr.Number(label="Count")
                gr.Radio(choices=["One", "Two"])
                gr.Image()
            with gr.Column():
                gr.Image(interactive=True)
                gr.Image()
    gr.Image()
    gr.Textbox(label="Text")
    gr.Number(label="Count")
    gr.Radio(choices=["One", "Two"])

demo.launch()
```

## Calculator

```python
import gradio as gr

def calculator(num1, operation, num2):
    if operation == "add":
        return num1 + num2
    elif operation == "subtract":
        return num1 - num2
    elif operation == "multiply":
        return num1 * num2
    elif operation == "divide":
        if num2 == 0:
            raise gr.Error("Cannot divide by zero!")
        return num1 / num2

demo = gr.Interface(
    calculator,
    [
        "number",
        gr.Radio(["add", "subtract", "multiply", "divide"]),
        "number"
    ],
    "number",
    examples=[
        [45, "add", 3],
        [3.14, "divide", 2],
        [144, "multiply", 2.5],
        [0, "subtract", 1.2],
    ],
    title="Toy Calculator",
    description="Here's a sample toy calculator.",
    api_name="predict"
)

demo.launch()
```

## Chatbot Simple

```python
import gradio as gr
import random
import time

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        bot_message = random.choice(["How are you?", "Today is a great day", "I'm very hungry"])
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": bot_message})
        time.sleep(2)
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.launch()
```

## Chatbot Streaming

```python
import gradio as gr
import random
import time

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history: list):
        return "", history + [{"role": "user", "content": user_message}]

    def bot(history: list):
        bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
        history.append({"role": "assistant", "content": ""})
        for character in bot_message:
            history[-1]['content'] += character
            time.sleep(0.05)
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()
```

## Custom CSS

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Column(elem_classes="cool-col"):
        gr.Markdown("### Gradio Demo with Custom CSS", elem_classes="darktest")
        gr.Markdown(
            elem_classes="markdown",
            value="Resize the browser window to see the CSS media query in action.",
        )

if __name__ == "__main__":
    demo.launch(css_paths=["demo/custom_css/custom_css.css"])
```

## Fake Diffusion

```python
import gradio as gr
import numpy as np
import time

def fake_diffusion(steps):
    rng = np.random.default_rng()
    for i in range(steps):
        time.sleep(1)
        image = rng.random(size=(600, 600, 3))
        yield image
    image = np.ones((1000, 1000, 3), np.uint8)
    image[:] = [255, 124, 0]
    yield image

demo = gr.Interface(fake_diffusion,
                    inputs=gr.Slider(1, 10, 3, step=1),
                    outputs="image",
                    api_name="predict")

demo.launch()
```

## Hello World

```python
import gradio as gr


def greet(name):
    return "Hello " + name + "!"


demo = gr.Interface(fn=greet, inputs="textbox", outputs="textbox", api_name="predict")

demo.launch()
```

## Image Editor

```python
import gradio as gr
import time


def sleep(im):
    time.sleep(5)
    return [im["background"], im["layers"][0], im["layers"][1], im["composite"]]


def predict(im):
    return im["composite"]


with gr.Blocks() as demo:
    with gr.Row():
        im = gr.ImageEditor(
            type="numpy",
        )
        im_preview = gr.Image()
    n_upload = gr.Number(0, label="Number of upload events", step=1)
    n_change = gr.Number(0, label="Number of change events", step=1)
    n_input = gr.Number(0, label="Number of input events", step=1)

    im.upload(lambda x: x + 1, outputs=n_upload, inputs=n_upload)
    im.change(lambda x: x + 1, outputs=n_change, inputs=n_change)
    im.input(lambda x: x + 1, outputs=n_input, inputs=n_input)
    im.change(predict, outputs=im_preview, inputs=im, show_progress="hidden")

demo.launch()
```

## On Listener Decorator

```python
import gradio as gr

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    output = gr.Textbox(label="Output Box")
    greet_btn = gr.Button("Greet")

    @gr.on(triggers=[name.submit, greet_btn.click], inputs=name, outputs=output)
    def greet(name):
        return "Hello " + name + "!"

demo.launch()
```

## Render Merge

```python
import gradio as gr
import time

with gr.Blocks() as demo:
    text_count = gr.Slider(1, 5, value=1, step=1, label="Textbox Count")

    @gr.render(inputs=text_count)
    def render_count(count):
        boxes = []
        for i in range(count):
            box = gr.Textbox(label=f"Box {i}")
            boxes.append(box)

        def merge(*args):
            time.sleep(0.2)  # simulate a delay
            return " ".join(args)

        merge_btn.click(merge, boxes, output)

        def clear():
            time.sleep(0.2)  # simulate a delay
            return [" "] * count

        clear_btn.click(clear, None, boxes)

        def countup():
            time.sleep(0.2)  # simulate a delay
            return list(range(count))

        count_btn.click(countup, None, boxes, queue=False)

    with gr.Row():
        merge_btn = gr.Button("Merge")
        clear_btn = gr.Button("Clear")
        count_btn = gr.Button("Count")

    output = gr.Textbox()

demo.launch()
```

## Reverse Audio 2

```python
import gradio as gr
import numpy as np

def reverse_audio(audio):
    sr, data = audio
    return (sr, np.flipud(data))

demo = gr.Interface(fn=reverse_audio,
                    inputs="microphone",
                    outputs="audio", api_name="predict")

demo.launch()
```

## Sepia Filter

```python
import numpy as np
import gradio as gr

def sepia(input_img):
    sepia_filter = np.array([
        [0.393, 0.769, 0.189],
        [0.349, 0.686, 0.168],
        [0.272, 0.534, 0.131]
    ])
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

demo = gr.Interface(sepia, gr.Image(), "image", api_name="predict")
demo.launch()
```

## Sort Records

```python
import gradio as gr

def sort_records(records):
    return records.sort("Quantity")

demo = gr.Interface(
    sort_records,
    gr.Dataframe(
        headers=["Item", "Quantity"],
        datatype=["str", "number"],
        row_count=3,
        column_count=2,
        column_limits=(2, 2),
        type="polars"
    ),
    "dataframe",
    description="Sort by Quantity"
)

demo.launch()
```

## Streaming Simple

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input", sources="webcam")
        with gr.Column():
            output_img = gr.Image(label="Output")
        input_img.stream(lambda s: s, input_img, output_img, time_limit=15, stream_every=0.1, concurrency_limit=30)

if __name__ == "__main__":
    demo.launch()
```

## Tabbed Interface Lite

```python
import gradio as gr

hello_world = gr.Interface(lambda name: "Hello " + name, "text", "text", api_name="predict")
bye_world = gr.Interface(lambda name: "Bye " + name, "text", "text", api_name="predict")
chat = gr.ChatInterface(lambda *args: "Hello " + args[0], api_name="chat")

demo = gr.TabbedInterface([hello_world, bye_world, chat], ["Hello World", "Bye World", "Chat"])

demo.launch()
```

## Tax Calculator

```python
import gradio as gr

def tax_calculator(income, marital_status, assets):
    tax_brackets = [(10, 0), (25, 8), (60, 12), (120, 20), (250, 30)]
    total_deductible = sum(cost for cost, deductible in zip(assets["Cost"], assets["Deductible"]) if deductible)
    taxable_income = income - total_deductible

    total_tax = 0
    for bracket, rate in tax_brackets:
        if taxable_income > bracket:
            total_tax += (taxable_income - bracket) * rate / 100

    if marital_status == "Married":
        total_tax *= 0.75
    elif marital_status == "Divorced":
        total_tax *= 0.8

    return round(total_tax)

demo = gr.Interface(
    tax_calculator,
    [
        "number",
        gr.Radio(["Single", "Married", "Divorced"]),
        gr.Dataframe(
            headers=["Item", "Cost", "Deductible"],
            datatype=["str", "number", "bool"],
            label="Assets Purchased this Year",
        ),
    ],
    gr.Number(label="Tax due"),
    examples=[
        [10000, "Married", [["Suit", 5000, True], ["Laptop (for work)", 800, False], ["Car", 1800, True]]],
        [80000, "Single", [["Suit", 800, True], ["Watch", 1800, True], ["Food", 800, True]]],
    ],
    live=True,
    api_name="predict"
)

demo.launch()
```

## Timer Simple

```python
import gradio as gr
import random
import time

with gr.Blocks() as demo:
    timer = gr.Timer(1)
    timestamp = gr.Number(label="Time")
    timer.tick(lambda: round(time.time()), outputs=timestamp, api_name="timestamp")

    number = gr.Number(lambda: random.randint(1, 10), every=timer, label="Random Number")
    with gr.Row():
        gr.Button("Start").click(lambda: gr.Timer(active=True), None, timer)
        gr.Button("Stop").click(lambda: gr.Timer(active=False), None, timer)
        gr.Button("Go Fast").click(lambda: 0.2, None, timer)

if __name__ == "__main__":
    demo.launch()
```

## Variable Outputs

```python
import gradio as gr

max_textboxes = 10

def variable_outputs(k):
    k = int(k)
    return [gr.Textbox(visible=True)]*k + [gr.Textbox(visible=False)]*(max_textboxes-k)

with gr.Blocks() as demo:
    s = gr.Slider(1, max_textboxes, value=max_textboxes, step=1, label="How many textboxes to show:")
    textboxes = []
    for i in range(max_textboxes):
        t = gr.Textbox(f"Textbox {i}")
        textboxes.append(t)

    s.change(variable_outputs, s, textboxes)

if __name__ == "__main__":
    demo.launch()
```

## Video Identity

```python
import gradio as gr
from gradio.media import get_video

def video_identity(video):
    return video

# get_video() returns file paths to sample media included with Gradio
demo = gr.Interface(video_identity,
                    gr.Video(),
                    "playable_video",
                    examples=[
                        get_video("world.mp4")
                    ],
                    cache_examples=True,
                    api_name="predict")

demo.launch()
```

|
||||
@@ -1,9 +1,9 @@
---
name: hugging-face-jobs
source: "https://github.com/huggingface/skills/tree/main/skills/hugging-face-jobs"
date_added: "2026-02-27"
description: Run workloads on Hugging Face Jobs with managed CPUs, GPUs, TPUs, secrets, and Hub persistence.
license: Complete terms in LICENSE.txt
risk: unknown
---

# Running Workloads on Hugging Face Jobs
@@ -66,12 +66,15 @@ Before starting any job, verify:

**How to provide tokens:**
```python
# hf_jobs MCP tool — $HF_TOKEN is auto-replaced with real token:
{"secrets": {"HF_TOKEN": "$HF_TOKEN"}}

# HfApi().run_uv_job() — MUST pass actual token:
from huggingface_hub import get_token
secrets={"HF_TOKEN": get_token()}
```

**⚠️ CRITICAL:** The `$HF_TOKEN` placeholder is ONLY auto-replaced by the `hf_jobs` MCP tool. When using `HfApi().run_uv_job()`, you MUST pass the real token via `get_token()`. Passing the literal string `"$HF_TOKEN"` results in a 9-character invalid token and 401 errors.

## Token Usage Guide

@@ -539,9 +542,12 @@ requests.post("https://your-api.com/results", json=results)

**In job submission:**
```python
# hf_jobs MCP tool:
{"secrets": {"HF_TOKEN": "$HF_TOKEN"}}  # auto-replaced

# HfApi().run_uv_job():
from huggingface_hub import get_token
secrets={"HF_TOKEN": get_token()}  # must pass real token
```

**In script:**

@@ -560,7 +566,7 @@ api.upload_file(...)

Before submitting:
- [ ] Results persistence method chosen
- [ ] Token in secrets if using Hub (MCP: `"$HF_TOKEN"`, Python API: `get_token()`)
- [ ] Script handles missing token gracefully
- [ ] Test persistence path works

@@ -950,7 +956,7 @@ hf_jobs("uv", {
### Hub Push Failures

**Fix:**
1. Add token to secrets: MCP uses `"$HF_TOKEN"` (auto-replaced), Python API uses `get_token()` (must pass real token)
2. Verify token in script: `assert "HF_TOKEN" in os.environ`
3. Check token permissions
4. Verify repo exists or can be created

@@ -969,7 +975,7 @@ Add to PEP 723 header:

**Fix:**
1. Check `hf_whoami()` works locally
2. Verify token in secrets — MCP: `"$HF_TOKEN"`, Python API: `get_token()` (NOT `"$HF_TOKEN"`)
3. Re-login: `hf auth login`
4. Check token has required permissions

@@ -1017,7 +1023,7 @@ Add to PEP 723 header:
2. **Jobs are asynchronous** - Don't wait/poll; let the user check when ready
3. **Always set timeout** - Default 30 min may be insufficient; set appropriate timeout
4. **Always persist results** - Environment is ephemeral; without persistence, all work is lost
5. **Use tokens securely** - MCP: `secrets={"HF_TOKEN": "$HF_TOKEN"}`, Python API: `secrets={"HF_TOKEN": get_token()}` — `"$HF_TOKEN"` only works with MCP tool
6. **Choose appropriate hardware** - Start small, scale up based on needs (see hardware guide)
7. **Use UV scripts** - Default to `hf_jobs("uv", {...})` with inline scripts for Python workloads
8. **Handle authentication** - Verify tokens are available before Hub operations

@@ -1033,6 +1039,7 @@ Add to PEP 723 header:
| List jobs | `hf_jobs("ps")` | `hf jobs ps` | `list_jobs()` |
| View logs | `hf_jobs("logs", {...})` | `hf jobs logs <id>` | `fetch_job_logs(job_id)` |
| Cancel job | `hf_jobs("cancel", {...})` | `hf jobs cancel <id>` | `cancel_job(job_id)` |
| Schedule UV | `hf_jobs("scheduled uv", {...})` | `hf jobs scheduled uv run SCHEDULE script.py` | `create_scheduled_uv_job()` |
| Schedule Docker | `hf_jobs("scheduled run", {...})` | `hf jobs scheduled run SCHEDULE image cmd` | `create_scheduled_job()` |
| List scheduled | `hf_jobs("scheduled ps")` | `hf jobs scheduled ps` | `list_scheduled_jobs()` |
| Delete scheduled | `hf_jobs("scheduled delete", {...})` | `hf jobs scheduled delete <id>` | `delete_scheduled_job()` |

@@ -0,0 +1,216 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>hf-jobs - Run Workloads on Hugging Face Jobs</title>
    <style>
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }

        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
            line-height: 1.6;
            color: #333;
            background: #f5f5f5;
            padding: 20px;
        }

        .container {
            max-width: 1200px;
            margin: 0 auto;
            background: white;
            padding: 40px;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }

        h1 {
            color: #ffd21e;
            background: #000;
            padding: 20px;
            margin: -40px -40px 30px -40px;
            border-radius: 8px 8px 0 0;
        }

        h2 {
            color: #1e1e1e;
            margin-top: 30px;
            margin-bottom: 15px;
            padding-bottom: 10px;
            border-bottom: 2px solid #ffd21e;
        }

        h3 {
            color: #555;
            margin-top: 20px;
            margin-bottom: 10px;
        }

        .description {
            background: #f9f9f9;
            padding: 20px;
            border-left: 4px solid #ffd21e;
            margin-bottom: 30px;
            border-radius: 4px;
        }

        .file-list {
            list-style: none;
            padding: 0;
        }

        .file-list li {
            padding: 12px;
            margin: 8px 0;
            background: #f9f9f9;
            border-radius: 4px;
            border-left: 3px solid #ffd21e;
            transition: background 0.2s;
        }

        .file-list li:hover {
            background: #f0f0f0;
        }

        .file-list a {
            color: #0066cc;
            text-decoration: none;
            font-weight: 500;
            display: block;
        }

        .file-list a:hover {
            text-decoration: underline;
        }

        .file-path {
            color: #666;
            font-size: 0.9em;
            font-family: 'Monaco', 'Courier New', monospace;
            margin-top: 4px;
        }

        .file-description {
            color: #777;
            font-size: 0.9em;
            margin-top: 4px;
            font-style: italic;
        }

        .metadata {
            background: #f0f0f0;
            padding: 15px;
            border-radius: 4px;
            margin-bottom: 30px;
        }

        .metadata p {
            margin: 5px 0;
        }

        .metadata strong {
            color: #333;
        }

        .section {
            margin-bottom: 40px;
        }

        code {
            background: #f4f4f4;
            padding: 2px 6px;
            border-radius: 3px;
            font-family: 'Monaco', 'Courier New', monospace;
            font-size: 0.9em;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>Agent Skill : hf-jobs</h1>

        <div class="description">
            <p><strong>Run any workload on Hugging Face Jobs.</strong></p>
            <p>Use this skill when you want to run GPU/CPU workloads (batch inference, synthetic data generation, dataset stats, experiments) on Hugging Face Jobs, with correct token handling and result persistence back to the Hub.</p>
        </div>

        <div class="metadata">
            <p><strong>Skill Name:</strong> hf-jobs</p>
            <p><strong>Main Documentation:</strong> <a href="hf-jobs/SKILL.md">hf-jobs/SKILL.md</a></p>
            <p><strong>Scripts Directory:</strong> <code>hf-jobs/scripts/</code></p>
            <p><strong>References Directory:</strong> <code>hf-jobs/references/</code></p>
        </div>

        <div class="section">
            <h2>Overview</h2>
            <p>This skill focuses on running real workloads via Hugging Face Jobs. It includes ready-to-run UV scripts and guides for authentication (HF tokens), secrets vs env vars, timeouts, hardware selection, and pushing results to the Hub.</p>
        </div>

        <div class="section">
            <h2>Core Documentation</h2>
            <ul class="file-list">
                <li>
                    <a href="hf-jobs/SKILL.md">SKILL.md</a>
                    <div class="file-path">hf-jobs/SKILL.md</div>
                    <div class="file-description">Complete skill documentation (how to submit jobs, tokens/secrets, timeouts, persistence, and how to use the bundled scripts)</div>
                </li>
            </ul>
        </div>

        <div class="section">
            <h2>References</h2>
            <ul class="file-list">
                <li>
                    <a href="hf-jobs/references/token_usage.md">token_usage.md</a>
                    <div class="file-path">hf-jobs/references/token_usage.md</div>
                    <div class="file-description">Token best practices: secrets vs env, permissions, common errors (401/403), and secure patterns</div>
                </li>
                <li>
                    <a href="hf-jobs/references/hub_saving.md">hub_saving.md</a>
                    <div class="file-path">hf-jobs/references/hub_saving.md</div>
                    <div class="file-description">How to persist results: push datasets/models/files to the Hub (ephemeral job filesystem)</div>
                </li>
                <li>
                    <a href="hf-jobs/references/hardware_guide.md">hardware_guide.md</a>
                    <div class="file-path">hf-jobs/references/hardware_guide.md</div>
                    <div class="file-description">Flavor selection guidance for CPU/GPU/TPU workloads</div>
                </li>
                <li>
                    <a href="hf-jobs/references/troubleshooting.md">troubleshooting.md</a>
                    <div class="file-path">hf-jobs/references/troubleshooting.md</div>
                    <div class="file-description">Common failure modes (timeouts, missing deps, OOM, auth) and fixes</div>
                </li>
            </ul>
        </div>

        <div class="section">
            <h2>Scripts</h2>
            <ul class="file-list">
                <li>
                    <a href="hf-jobs/scripts/generate-responses.py">generate-responses.py</a>
                    <div class="file-path">hf-jobs/scripts/generate-responses.py</div>
                    <div class="file-description">vLLM batch generation: load prompts/messages from a dataset, generate responses, push dataset + card to Hub</div>
                </li>
                <li>
                    <a href="hf-jobs/scripts/cot-self-instruct.py">cot-self-instruct.py</a>
                    <div class="file-path">hf-jobs/scripts/cot-self-instruct.py</div>
                    <div class="file-description">CoT Self-Instruct synthetic data generation (reasoning/instruction) + optional filtering, pushes dataset + card</div>
                </li>
                <li>
                    <a href="hf-jobs/scripts/finepdfs-stats.py">finepdfs-stats.py</a>
                    <div class="file-path">hf-jobs/scripts/finepdfs-stats.py</div>
                    <div class="file-description">Polars streaming stats over Hub parquet (finepdfs-edu); optional upload of computed stats to a dataset repo</div>
                </li>
            </ul>
        </div>
    </div>
</body>
</html>

@@ -0,0 +1,336 @@
# Hardware Selection Guide

Choosing the right hardware (flavor) is critical for cost-effective workloads.

> **Reference:** [HF Jobs Hardware Documentation](https://huggingface.co/docs/hub/en/spaces-config-reference) (updated 07/2025)

## Available Hardware

### CPU Flavors

| Flavor | Description | Use Case |
|--------|-------------|----------|
| `cpu-basic` | Basic CPU instance | Testing, lightweight scripts |
| `cpu-upgrade` | Enhanced CPU instance | Data processing, parallel workloads |

**Use cases:** Data processing, testing scripts, lightweight workloads
**Not recommended for:** Model training, GPU-accelerated workloads

### GPU Flavors

| Flavor | GPU | VRAM | Use Case |
|--------|-----|------|----------|
| `t4-small` | NVIDIA T4 | 16GB | <1B models, demos, quick tests |
| `t4-medium` | NVIDIA T4 | 16GB | 1-3B models, development |
| `l4x1` | NVIDIA L4 | 24GB | 3-7B models, efficient workloads |
| `l4x4` | 4x NVIDIA L4 | 96GB | Multi-GPU, parallel workloads |
| `a10g-small` | NVIDIA A10G | 24GB | 3-7B models, production |
| `a10g-large` | NVIDIA A10G | 24GB | 7-13B models, batch inference |
| `a10g-largex2` | 2x NVIDIA A10G | 48GB | Multi-GPU, large models |
| `a10g-largex4` | 4x NVIDIA A10G | 96GB | Multi-GPU, very large models |
| `a100-large` | NVIDIA A100 | 40GB | 13B+ models, fastest GPU option |

### TPU Flavors

| Flavor | Configuration | Use Case |
|--------|---------------|----------|
| `v5e-1x1` | TPU v5e (1x1) | Small TPU workloads |
| `v5e-2x2` | TPU v5e (2x2) | Medium TPU workloads |
| `v5e-2x4` | TPU v5e (2x4) | Large TPU workloads |

**TPU Use Cases:**
- JAX/Flax model training
- Large-scale inference
- TPU-optimized workloads

## Selection Guidelines

### By Workload Type

**Data Processing**
- **Recommended:** `cpu-upgrade` or `l4x1`
- **Use case:** Transform, filter, analyze datasets
- **Batch size:** Depends on data size
- **Time:** Varies by dataset size

**Batch Inference**
- **Recommended:** `a10g-large` or `a100-large`
- **Use case:** Run inference on thousands of samples
- **Batch size:** 8-32 depending on model
- **Time:** Depends on number of samples

**Experiments & Benchmarks**
- **Recommended:** `a10g-small` or `a10g-large`
- **Use case:** Reproducible ML experiments
- **Batch size:** Varies
- **Time:** Depends on experiment complexity

**Model Training** (see `model-trainer` skill for details)
- **Recommended:** See model-trainer skill
- **Use case:** Fine-tuning models
- **Batch size:** Depends on model size
- **Time:** Hours to days

**Synthetic Data Generation**
- **Recommended:** `a10g-large` or `a100-large`
- **Use case:** Generate datasets using LLMs
- **Batch size:** Depends on generation method
- **Time:** Hours for large datasets

### By Budget

**Minimal Budget (<$5 total)**
- Use `cpu-basic` or `t4-small`
- Process small datasets
- Quick tests and demos

**Small Budget ($5-20)**
- Use `t4-medium` or `a10g-small`
- Process medium datasets
- Run experiments

**Medium Budget ($20-50)**
- Use `a10g-small` or `a10g-large`
- Process large datasets
- Production workloads

**Large Budget ($50-200)**
- Use `a10g-large` or `a100-large`
- Large-scale processing
- Multiple experiments

### By Model Size (for inference/processing)

**Tiny Models (<1B parameters)**
- **Recommended:** `t4-small`
- **Example:** Qwen2.5-0.5B, TinyLlama
- **Batch size:** 8-16

**Small Models (1-3B parameters)**
- **Recommended:** `t4-medium` or `a10g-small`
- **Example:** Qwen2.5-1.5B, Phi-2
- **Batch size:** 4-8

**Medium Models (3-7B parameters)**
- **Recommended:** `a10g-small` or `a10g-large`
- **Example:** Qwen2.5-7B, Mistral-7B
- **Batch size:** 2-4

**Large Models (7-13B parameters)**
- **Recommended:** `a10g-large` or `a100-large`
- **Example:** Llama-3-8B
- **Batch size:** 1-2

**Very Large Models (13B+ parameters)**
- **Recommended:** `a100-large`
- **Example:** Llama-3-13B, Llama-3-70B
- **Batch size:** 1

## Memory Considerations

### Estimating Memory Requirements

**For inference:**
```
Memory (GB) ≈ (Model params in billions) × 2-4
```

**For training:**
```
Memory (GB) ≈ (Model params in billions) × 20 (full) or × 4 (LoRA)
```

**Examples:**
- Qwen2.5-0.5B inference: ~1-2GB ✅ fits t4-small
- Qwen2.5-7B inference: ~14-28GB ✅ fits a10g-large
- Qwen2.5-7B training: ~140GB ❌ not feasible without LoRA
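
These rules of thumb can be wrapped in a small helper. This is a rough sketch: the multipliers are the heuristics above, not exact measurements, and real usage also depends on sequence length, batch size, and KV cache.

```python
def estimate_memory_gb(params_billions: float, mode: str = "inference", lora: bool = False) -> float:
    """Rough VRAM estimate from the rule-of-thumb multipliers above."""
    if mode == "inference":
        factor = 2  # low end of the 2-4 range (fp16 weights only)
    elif mode == "training":
        factor = 4 if lora else 20  # LoRA vs full fine-tuning
    else:
        raise ValueError(f"unknown mode: {mode}")
    return params_billions * factor

print(estimate_memory_gb(7))              # 14.0 -> fits a10g-large (24GB)
print(estimate_memory_gb(7, "training"))  # 140.0 -> needs LoRA or multi-GPU
```

Comparing the estimate against a flavor's VRAM (from the tables above) is a quick sanity check before submitting a job.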

### Memory Optimization

If hitting memory limits:

1. **Reduce batch size**
   ```python
   batch_size = 1
   ```

2. **Process in chunks**
   ```python
   for chunk in chunks:
       process(chunk)
   ```

3. **Use smaller models**
   - Use quantized models
   - Use LoRA adapters

4. **Upgrade hardware**
   - cpu → t4 → a10g → a100

## Cost Estimation

### Formula

```
Total Cost = (Hours of runtime) × (Cost per hour)
```

### Example Calculations

**Data processing:**
- Hardware: cpu-upgrade ($0.50/hour)
- Time: 1 hour
- Cost: $0.50

**Batch inference:**
- Hardware: a10g-large ($5/hour)
- Time: 2 hours
- Cost: $10.00

**Experiments:**
- Hardware: a10g-small ($3.50/hour)
- Time: 4 hours
- Cost: $14.00
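
The formula is simple enough to script. The hourly rates below are the illustrative figures from the worked examples above, not an official price list:

```python
# Illustrative $/hour rates taken from the worked examples above (not official pricing)
RATES = {"cpu-upgrade": 0.50, "a10g-small": 3.50, "a10g-large": 5.00}

def job_cost(flavor: str, hours: float) -> float:
    """Total Cost = (hours of runtime) x (cost per hour)."""
    return round(RATES[flavor] * hours, 2)

print(job_cost("cpu-upgrade", 1))  # 0.5
print(job_cost("a10g-large", 2))   # 10.0
print(job_cost("a10g-small", 4))   # 14.0
```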

### Cost Optimization Tips

1. **Start small:** Test on cpu-basic or t4-small
2. **Monitor runtime:** Set appropriate timeouts
3. **Optimize code:** Reduce unnecessary compute
4. **Choose right hardware:** Don't over-provision
5. **Use checkpoints:** Resume if job fails
6. **Monitor costs:** Check running jobs regularly

## Multi-GPU Workloads

Multi-GPU flavors automatically distribute workloads:

**Multi-GPU flavors:**
- `l4x4` - 4x L4 GPUs (96GB total VRAM)
- `a10g-largex2` - 2x A10G GPUs (48GB total VRAM)
- `a10g-largex4` - 4x A10G GPUs (96GB total VRAM)

**When to use:**
- Large models (>13B parameters)
- Need faster processing (linear speedup)
- Large datasets (>100K samples)
- Parallel workloads
- Tensor parallelism for inference

**MCP Tool Example:**
```python
hf_jobs("uv", {
    "script": "process.py",
    "flavor": "a10g-largex2",  # 2 GPUs
    "timeout": "4h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

**CLI Equivalent:**
```bash
hf jobs uv run process.py --flavor a10g-largex2 --timeout 4h
```

## Choosing Between Options

### CPU vs GPU

**Choose CPU when:**
- No GPU acceleration needed
- Data processing only
- Budget constrained
- Simple workloads

**Choose GPU when:**
- Model inference/training
- GPU-accelerated libraries
- Need faster processing
- Large models

### a10g vs a100

**Choose a10g when:**
- Model <13B parameters
- Budget conscious
- Processing time not critical

**Choose a100 when:**
- Model 13B+ parameters
- Need fastest processing
- Memory requirements high
- Budget allows

### Single vs Multi-GPU

**Choose single GPU when:**
- Model <7B parameters
- Budget constrained
- Simpler debugging

**Choose multi-GPU when:**
- Model >13B parameters
- Need faster processing
- Large batch sizes required
- Cost-effective for large jobs

## Quick Reference

### All Available Flavors

```python
# Official flavor list (updated 07/2025)
FLAVORS = {
    # CPU
    "cpu-basic",       # Testing, lightweight
    "cpu-upgrade",     # Data processing

    # GPU - Single
    "t4-small",        # 16GB - <1B models
    "t4-medium",       # 16GB - 1-3B models
    "l4x1",            # 24GB - 3-7B models
    "a10g-small",      # 24GB - 3-7B production
    "a10g-large",      # 24GB - 7-13B models
    "a100-large",      # 40GB - 13B+ models

    # GPU - Multi
    "l4x4",            # 4x L4 (96GB total)
    "a10g-largex2",    # 2x A10G (48GB total)
    "a10g-largex4",    # 4x A10G (96GB total)

    # TPU
    "v5e-1x1",         # TPU v5e 1x1
    "v5e-2x2",         # TPU v5e 2x2
    "v5e-2x4",         # TPU v5e 2x4
}
```

### Workload → Hardware Mapping

```python
HARDWARE_MAP = {
    "data_processing": "cpu-upgrade",
    "batch_inference_small": "t4-small",
    "batch_inference_medium": "a10g-large",
    "batch_inference_large": "a100-large",
    "experiments": "a10g-small",
    "tpu_workloads": "v5e-1x1",
    "training": "see model-trainer skill"
}
```

### CLI Examples

```bash
# CPU job
hf jobs run python:3.12 python script.py

# GPU job
hf jobs run --flavor a10g-large pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel python script.py

# TPU job
hf jobs run --flavor v5e-1x1 your-tpu-image python script.py

# UV script with GPU
hf jobs uv run --flavor a10g-small my_script.py
```

@@ -0,0 +1,352 @@
# Saving Results to Hugging Face Hub

**⚠️ CRITICAL:** Job environments are ephemeral. ALL results are lost when a job completes unless persisted to the Hub or external storage.

## Why Persistence is Required

When running on Hugging Face Jobs:
- Environment is temporary
- All files deleted on job completion
- No local disk persistence
- Cannot access results after job ends

**Without persistence, all work is permanently lost.**

## Persistence Options

### Option 1: Push to Hugging Face Hub (Recommended)

**For models:**
```python
from transformers import AutoModel
model.push_to_hub("username/model-name", token=os.environ.get("HF_TOKEN"))
```

**For datasets:**
```python
from datasets import Dataset
dataset.push_to_hub("username/dataset-name", token=os.environ.get("HF_TOKEN"))
```

**For files/artifacts:**
```python
from huggingface_hub import HfApi
api = HfApi(token=os.environ.get("HF_TOKEN"))
api.upload_file(
    path_or_fileobj="results.json",
    path_in_repo="results.json",
    repo_id="username/results",
    repo_type="dataset"
)
```

### Option 2: External Storage

**S3:**
```python
import boto3
s3 = boto3.client('s3')
s3.upload_file('results.json', 'my-bucket', 'results.json')
```

**Google Cloud Storage:**
```python
from google.cloud import storage
client = storage.Client()
bucket = client.bucket('my-bucket')
blob = bucket.blob('results.json')
blob.upload_from_filename('results.json')
```

### Option 3: API Endpoint

```python
import requests
requests.post("https://your-api.com/results", json=results)
```

## Required Configuration for Hub Push

### Job Configuration

**Always include HF_TOKEN:**
```python
hf_jobs("uv", {
    "script": "your_script.py",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # ✅ Required for Hub operations
})
```

### Script Configuration

**Verify token exists:**
```python
import os
assert "HF_TOKEN" in os.environ, "HF_TOKEN required for Hub operations!"
```

**Use token for Hub operations:**
```python
from huggingface_hub import HfApi

# Auto-detects HF_TOKEN from environment
api = HfApi()

# Or explicitly pass token
api = HfApi(token=os.environ.get("HF_TOKEN"))
```

## Complete Examples

### Example 1: Push Dataset

```python
hf_jobs("uv", {
    "script": """
# /// script
# dependencies = ["datasets", "huggingface-hub"]
# ///

import os
from datasets import Dataset
from huggingface_hub import HfApi

# Verify token
assert "HF_TOKEN" in os.environ, "HF_TOKEN required!"

# Process data
data = {"text": ["Sample 1", "Sample 2"]}
dataset = Dataset.from_dict(data)

# Push to Hub
dataset.push_to_hub("username/my-dataset")
print("✅ Dataset pushed!")
""",
    "flavor": "cpu-basic",
    "timeout": "30m",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

### Example 2: Push Model

```python
hf_jobs("uv", {
    "script": """
# /// script
# dependencies = ["transformers"]
# ///

import os
from transformers import AutoModel, AutoTokenizer

# Verify token
assert "HF_TOKEN" in os.environ, "HF_TOKEN required!"

# Load and process model
model = AutoModel.from_pretrained("base-model")
tokenizer = AutoTokenizer.from_pretrained("base-model")
# ... process model ...

# Push to Hub
model.push_to_hub("username/my-model")
tokenizer.push_to_hub("username/my-model")
print("✅ Model pushed!")
""",
    "flavor": "a10g-large",
    "timeout": "2h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

### Example 3: Push Artifacts

```python
hf_jobs("uv", {
    "script": """
# /// script
# dependencies = ["huggingface-hub", "pandas"]
# ///

import os
import json
import pandas as pd
from huggingface_hub import HfApi

# Verify token
assert "HF_TOKEN" in os.environ, "HF_TOKEN required!"

# Generate results
results = {"accuracy": 0.95, "loss": 0.05}
df = pd.DataFrame([results])

# Save files
with open("results.json", "w") as f:
    json.dump(results, f)
df.to_csv("results.csv", index=False)

# Push to Hub
api = HfApi()
api.upload_file("results.json", "results.json", "username/results", repo_type="dataset")
api.upload_file("results.csv", "results.csv", "username/results", repo_type="dataset")
print("✅ Results pushed!")
""",
    "flavor": "cpu-basic",
    "timeout": "30m",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

## Authentication Methods

### Method 1: Automatic Token (Recommended)

```python
"secrets": {"HF_TOKEN": "$HF_TOKEN"}
```

Uses your logged-in Hugging Face token automatically.

### Method 2: Explicit Token

```python
"secrets": {"HF_TOKEN": "hf_abc123..."}
```

Provide token explicitly (not recommended for security).

### Method 3: Environment Variable

```python
"env": {"HF_TOKEN": "hf_abc123..."}
```

Pass as regular environment variable (less secure than secrets).

**Always prefer Method 1** for security and convenience.

## Verification Checklist

Before submitting any job that saves to Hub, verify:

- [ ] `secrets={"HF_TOKEN": "$HF_TOKEN"}` in job config
- [ ] Script checks for token: `assert "HF_TOKEN" in os.environ`
- [ ] Hub push code included in script
- [ ] Repository name doesn't conflict with existing repos
- [ ] You have write access to the target namespace

## Repository Setup

### Automatic Creation

If the repository doesn't exist, it's created automatically on first push (if the token has write permissions).

### Manual Creation

Create repository before pushing:

```python
from huggingface_hub import HfApi

api = HfApi()
api.create_repo(
    repo_id="username/repo-name",
    repo_type="model",  # or "dataset"
    private=False,  # or True for private repo
)
```

### Repository Naming

**Valid names:**
- `username/my-model`
- `username/model-name`
- `organization/model-name`

**Invalid names:**
- `model-name` (missing username)
- `username/model name` (spaces not allowed)
- `username/MODEL` (uppercase discouraged)
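
These rules can be approximated with a small validator. This is a hypothetical helper for illustration, not a Hub API; the Hub's real validation differs in details.

```python
import re

def looks_like_valid_repo_id(repo_id: str) -> bool:
    """Heuristic check for the naming rules listed above."""
    if repo_id.count("/") != 1:
        return False  # must be namespace/name
    namespace, name = repo_id.split("/")
    if not namespace or not name or " " in repo_id:
        return False  # no empty parts, no spaces
    if name != name.lower():
        return False  # uppercase is discouraged; reject it here
    return bool(re.fullmatch(r"[A-Za-z0-9._-]+", name))

print(looks_like_valid_repo_id("username/my-model"))    # True
print(looks_like_valid_repo_id("model-name"))           # False (missing username)
print(looks_like_valid_repo_id("username/model name"))  # False (space)
print(looks_like_valid_repo_id("username/MODEL"))       # False (uppercase)
```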

## Troubleshooting

### Error: 401 Unauthorized

**Cause:** HF_TOKEN not provided or invalid

**Solutions:**
1. Verify `secrets={"HF_TOKEN": "$HF_TOKEN"}` in job config
2. Check you're logged in: `hf_whoami()`
3. Re-login: `hf auth login`

### Error: 403 Forbidden

**Cause:** No write access to repository

**Solutions:**
1. Check repository namespace matches your username
2. Verify you're a member of the organization (if using an org namespace)
3. Check token has write permissions

### Error: Repository not found

**Cause:** Repository doesn't exist and auto-creation failed

**Solutions:**
1. Manually create repository first
2. Check repository name format
3. Verify namespace exists

### Error: Push failed

**Cause:** Network issues or Hub unavailable

**Solutions:**
1. Check logs for specific error
2. Verify token is valid
3. Retry push operation

## Best Practices

1. **Always verify token exists** before Hub operations
2. **Use descriptive repo names** (e.g., `my-experiment-results` not `results`)
3. **Push incrementally** for large results (use checkpoints)
4. **Verify push success** in logs before job completes
5. **Use appropriate repo types** (model vs dataset)
6. **Add README** with result descriptions
7. **Tag repos** with relevant tags

## Monitoring Push Progress

Check logs for push progress:

**MCP Tool:**
```python
hf_jobs("logs", {"job_id": "your-job-id"})
```

**CLI:**
```bash
hf jobs logs <job-id>
```

**Python API:**
```python
from huggingface_hub import fetch_job_logs
for log in fetch_job_logs(job_id="your-job-id"):
    print(log)
```

**Look for:**
```
Pushing to username/repo-name...
Upload file results.json: 100%
✅ Push successful
```

## Key Takeaway

**Without `secrets={"HF_TOKEN": "$HF_TOKEN"}` and persistence code, all results are permanently lost.**

Always verify both are configured before submitting any job that produces results.

@@ -0,0 +1,570 @@
# Token Usage Guide for Hugging Face Jobs

**⚠️ CRITICAL:** Proper token usage is essential for any job that interacts with the Hugging Face Hub.

## Overview

Hugging Face tokens are authentication credentials that allow your jobs to interact with the Hub. They're required for:
- Pushing models/datasets to the Hub
- Accessing private repositories
- Creating new repositories
- Using Hub APIs programmatically
- Any authenticated Hub operations

## Token Types

### Read Token
- **Permissions:** Download models/datasets, read private repos
- **Use case:** Jobs that only need to download/read content
- **Creation:** https://huggingface.co/settings/tokens

### Write Token
- **Permissions:** Push models/datasets, create repos, modify content
- **Use case:** Jobs that need to upload results (most common)
- **Creation:** https://huggingface.co/settings/tokens
- **⚠️ Required for:** Pushing models, datasets, or any uploads

### Organization Token
- **Permissions:** Act on behalf of an organization
- **Use case:** Jobs running under an organization namespace
- **Creation:** Organization settings → Tokens

## Providing Tokens to Jobs

### Method 1: `hf_jobs` MCP tool with `$HF_TOKEN` (Recommended) ⭐

```python
hf_jobs("uv", {
    "script": "your_script.py",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # ✅ Automatic replacement
})
```

**How it works:**
1. `$HF_TOKEN` is a placeholder that gets replaced with your actual token
2. Uses the token from your logged-in session (`hf auth login`)
3. Token is encrypted server-side when passed as a secret
4. Most secure and convenient method

**Benefits:**
- ✅ No token exposure in code
- ✅ Uses your current login session
- ✅ Automatically updated if you re-login
- ✅ Works seamlessly with MCP tools
- ✅ Token encrypted server-side

**Requirements:**
- Must be logged in: `hf auth login` or `hf_whoami()` works
- Token must have required permissions

**⚠️ CRITICAL:** `$HF_TOKEN` auto-replacement is an `hf_jobs` MCP tool feature ONLY. It does NOT work with `HfApi().run_uv_job()` — see Method 1b below.

### Method 1b: `HfApi().run_uv_job()` with `get_token()` (Required for Python API)

```python
from huggingface_hub import HfApi, get_token

api = HfApi()
api.run_uv_job(
    script="your_script.py",
    secrets={"HF_TOKEN": get_token()},  # ✅ Passes actual token value
)
```

**How it works:**
1. `get_token()` retrieves the token from your logged-in session
2. The actual token value is passed to the `secrets` parameter
3. Token is encrypted server-side

**Why `"$HF_TOKEN"` fails with `HfApi().run_uv_job()`:**
- The Python API passes the literal string `"$HF_TOKEN"` (9 characters) as the token
- The Jobs server receives this invalid string instead of a real token
- Result: `401 Unauthorized` errors when the script tries to authenticate
- You MUST use `get_token()` from `huggingface_hub` to get the real token

### Method 2: Explicit Token (Not Recommended)

```python
hf_jobs("uv", {
    "script": "your_script.py",
    "secrets": {"HF_TOKEN": "hf_abc123..."}  # ⚠️ Hardcoded token
})
```

**When to use:**
- Only if automatic token doesn't work
- Testing with a specific token
- Organization tokens (use with caution)

**Security concerns:**
- ❌ Token visible in code/logs
- ❌ Must manually update if token rotates
- ❌ Risk of token exposure
- ❌ Not recommended for production

### Method 3: Environment Variable (Less Secure)

```python
hf_jobs("uv", {
    "script": "your_script.py",
    "env": {"HF_TOKEN": "hf_abc123..."}  # ⚠️ Less secure than secrets
})
```

**Difference from secrets:**
- `env` variables are visible in job logs
- `secrets` are encrypted server-side
- Always prefer `secrets` for tokens

**When to use:**
- Only for non-sensitive configuration
- Never use for tokens (use `secrets` instead)

## Using Tokens in Scripts

### Accessing Tokens

Tokens passed via `secrets` are available as environment variables in your script:

```python
import os

# Get token from environment
token = os.environ.get("HF_TOKEN")

# Verify token exists
if not token:
    raise ValueError("HF_TOKEN not found in environment!")
```

### Using with Hugging Face Hub

**Option 1: Explicit token parameter**
```python
from huggingface_hub import HfApi

api = HfApi(token=os.environ.get("HF_TOKEN"))
api.upload_file(...)
```

**Option 2: Auto-detection (Recommended)**
```python
from huggingface_hub import HfApi

# Automatically uses HF_TOKEN env var
api = HfApi()  # ✅ Simpler, uses token from environment
api.upload_file(...)
```

**Option 3: With transformers/datasets**
```python
from transformers import AutoModel
from datasets import load_dataset

# Auto-detects HF_TOKEN from environment
model = AutoModel.from_pretrained("username/model")
dataset = load_dataset("username/dataset")

# For push operations, token is auto-detected
model.push_to_hub("username/new-model")
dataset.push_to_hub("username/new-dataset")
```

### Complete Example

```python
# /// script
# dependencies = ["huggingface-hub", "datasets"]
# ///

import os
from huggingface_hub import HfApi
from datasets import Dataset

# Verify token is available
assert "HF_TOKEN" in os.environ, "HF_TOKEN required for Hub operations!"

# Use token for Hub operations
api = HfApi()  # Auto-detects HF_TOKEN

# Create and push dataset
data = {"text": ["Hello", "World"]}
dataset = Dataset.from_dict(data)

# Push to Hub (token auto-detected)
dataset.push_to_hub("username/my-dataset")

print("✅ Dataset pushed successfully!")
```

## Token Verification

### Check Authentication Locally

```python
from huggingface_hub import whoami

try:
    user_info = whoami()
    print(f"✅ Logged in as: {user_info['name']}")
except Exception as e:
    print(f"❌ Not authenticated: {e}")
```

### Verify Token in Job

```python
import os

# Check token exists
if "HF_TOKEN" not in os.environ:
    raise ValueError("HF_TOKEN not found in environment!")

token = os.environ["HF_TOKEN"]

# Verify token format (should start with "hf_")
if not token.startswith("hf_"):
    raise ValueError(f"Invalid token format: {token[:10]}...")

# Test token works
from huggingface_hub import whoami
try:
    user_info = whoami(token=token)
    print(f"✅ Token valid for user: {user_info['name']}")
except Exception as e:
    raise ValueError(f"Token validation failed: {e}")
```

## Common Token Issues

### Error: 401 Unauthorized

**Symptoms:**
```
401 Client Error: Unauthorized for url: https://huggingface.co/api/...
```

**Causes:**
1. Token missing from job
2. Token invalid or expired
3. Token not passed correctly

**Solutions:**
1. Add `secrets={"HF_TOKEN": "$HF_TOKEN"}` to job config
2. Verify `hf_whoami()` works locally
3. Re-login: `hf auth login`
4. Check token hasn't expired

**Verification:**
```python
# In your script
import os
assert "HF_TOKEN" in os.environ, "HF_TOKEN missing!"
```

### Error: 403 Forbidden

**Symptoms:**
```
403 Client Error: Forbidden for url: https://huggingface.co/api/...
```

**Causes:**
1. Token lacks required permissions (read-only token used for write)
2. No access to private repository
3. Organization permissions insufficient

**Solutions:**
1. Ensure token has write permissions
2. Check token type at https://huggingface.co/settings/tokens
3. Verify access to target repository
4. Use organization token if needed

**Check token permissions:**
```python
from huggingface_hub import whoami

user_info = whoami()
print(f"User: {user_info['name']}")
print(f"Type: {user_info.get('type', 'user')}")
```

### Error: Token not found in environment

**Symptoms:**
```
KeyError: 'HF_TOKEN'
ValueError: HF_TOKEN not found
```

**Causes:**
1. `secrets` not passed in job config
2. Wrong key name (should be `HF_TOKEN`)
3. Using `env` instead of `secrets`

**Solutions:**
1. Use `secrets={"HF_TOKEN": "$HF_TOKEN"}` (not `env`)
2. Verify key name is exactly `HF_TOKEN`
3. Check job config syntax

**Correct configuration:**
```python
# ✅ Correct
hf_jobs("uv", {
    "script": "...",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})

# ❌ Wrong - using env instead of secrets
hf_jobs("uv", {
    "script": "...",
    "env": {"HF_TOKEN": "$HF_TOKEN"}  # Less secure
})

# ❌ Wrong - wrong key name
hf_jobs("uv", {
    "script": "...",
    "secrets": {"TOKEN": "$HF_TOKEN"}  # Wrong key
})
```

### Error: Repository access denied

**Symptoms:**
```
403 Client Error: Forbidden
Repository not found or access denied
```

**Causes:**
1. Token doesn't have access to private repo
2. Repository doesn't exist and can't be created
3. Wrong namespace

**Solutions:**
1. Use token from account with access
2. Verify repo visibility (public vs private)
3. Check namespace matches token owner
4. Create repo first if needed

**Check repository access:**
```python
from huggingface_hub import HfApi

api = HfApi()
try:
    repo_info = api.repo_info("username/repo-name")
    print(f"✅ Access granted: {repo_info.id}")
except Exception as e:
    print(f"❌ Access denied: {e}")
```

## Token Security Best Practices

### 1. Never Commit Tokens

**❌ Bad:**
```python
# Never do this!
token = "hf_abc123xyz..."
api = HfApi(token=token)
```

**✅ Good:**
```python
# Use environment variable
token = os.environ.get("HF_TOKEN")
api = HfApi(token=token)
```

### 2. Use Secrets, Not Environment Variables

**❌ Bad:**
```python
hf_jobs("uv", {
    "script": "...",
    "env": {"HF_TOKEN": "$HF_TOKEN"}  # Visible in logs
})
```

**✅ Good:**
```python
hf_jobs("uv", {
    "script": "...",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # Encrypted server-side
})
```

### 3. Use Automatic Token Replacement

**❌ Bad:**
```python
hf_jobs("uv", {
    "script": "...",
    "secrets": {"HF_TOKEN": "hf_abc123..."}  # Hardcoded
})
```

**✅ Good:**
```python
hf_jobs("uv", {
    "script": "...",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # Automatic
})
```

### 4. Rotate Tokens Regularly

- Generate new tokens periodically
- Revoke old tokens
- Update job configurations
- Monitor token usage

### 5. Use Minimal Permissions

- Create tokens with only needed permissions
- Use read tokens when write isn't needed
- Don't use admin tokens for regular jobs

### 6. Don't Share Tokens

- Each user should use their own token
- Don't commit tokens to repositories
- Don't share tokens in logs or messages

### 7. Monitor Token Usage

- Check token activity in Hub settings
- Review job logs for token issues
- Set up alerts for unauthorized access

## Token Workflow Examples

### Example 1: Push Model to Hub

```python
hf_jobs("uv", {
    "script": """
# /// script
# dependencies = ["transformers"]
# ///

import os
from transformers import AutoModel, AutoTokenizer

# Verify token
assert "HF_TOKEN" in os.environ, "HF_TOKEN required!"

# Load and process model
model = AutoModel.from_pretrained("base-model")
# ... process model ...

# Push to Hub (token auto-detected)
model.push_to_hub("username/my-model")
print("✅ Model pushed!")
""",
    "flavor": "a10g-large",
    "timeout": "2h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # ✅ Token provided
})
```

### Example 2: Access Private Dataset

```python
hf_jobs("uv", {
    "script": """
# /// script
# dependencies = ["datasets"]
# ///

import os
from datasets import load_dataset

# Verify token
assert "HF_TOKEN" in os.environ, "HF_TOKEN required!"

# Load private dataset (token auto-detected)
dataset = load_dataset("private-org/private-dataset")
print(f"✅ Loaded {len(dataset)} examples")
""",
    "flavor": "cpu-basic",
    "timeout": "30m",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # ✅ Token provided
})
```

### Example 3: Create and Push Dataset

```python
hf_jobs("uv", {
    "script": """
# /// script
# dependencies = ["datasets", "huggingface-hub"]
# ///

import os
from datasets import Dataset
from huggingface_hub import HfApi

# Verify token
assert "HF_TOKEN" in os.environ, "HF_TOKEN required!"

# Create dataset
data = {"text": ["Sample 1", "Sample 2"]}
dataset = Dataset.from_dict(data)

# Push to Hub
api = HfApi()  # Auto-detects HF_TOKEN
dataset.push_to_hub("username/my-dataset")
print("✅ Dataset pushed!")
""",
    "flavor": "cpu-basic",
    "timeout": "30m",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # ✅ Token provided
})
```

## Quick Reference

### Token Checklist

Before submitting a job that uses the Hub:

- [ ] Job includes `secrets={"HF_TOKEN": "$HF_TOKEN"}`
- [ ] Script checks for token: `assert "HF_TOKEN" in os.environ`
- [ ] Token has required permissions (read/write)
- [ ] User is logged in: `hf_whoami()` works
- [ ] Token not hardcoded in script
- [ ] Using `secrets` not `env` for token

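The string-level items on this checklist can be automated with a short local preflight. A minimal sketch, offline only: `token_problems` is a hypothetical helper that validates the token string itself and does not contact the Hub (pair it with `whoami()` for a live check).

```python
import os

def token_problems(token):
    """Return a list of obvious problems with a token string (offline check only)."""
    problems = []
    if not token:
        problems.append("HF_TOKEN is missing")
    elif not token.startswith("hf_"):
        problems.append("token does not look like a Hugging Face token (expected 'hf_' prefix)")
    return problems

# Run against the current shell environment before submitting a job
print(token_problems(os.environ.get("HF_TOKEN", "")))
```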
### Common Patterns

**Pattern 1: Auto-detect token**
```python
from huggingface_hub import HfApi
api = HfApi()  # Uses HF_TOKEN from environment
```

**Pattern 2: Explicit token**
```python
import os
from huggingface_hub import HfApi
api = HfApi(token=os.environ.get("HF_TOKEN"))
```

**Pattern 3: Verify token**
```python
import os
assert "HF_TOKEN" in os.environ, "HF_TOKEN required!"
```

## Key Takeaways

1. **Always use `secrets={"HF_TOKEN": "$HF_TOKEN"}`** for Hub operations
2. **Never hardcode tokens** in scripts or job configs
3. **Verify token exists** in script before Hub operations
4. **Use auto-detection** when possible (`HfApi()` without token parameter)
5. **Check permissions** - ensure token has required access
6. **Monitor token usage** - review activity regularly
7. **Rotate tokens** - generate new tokens periodically

@@ -0,0 +1,475 @@
# Troubleshooting Guide

Common issues and solutions for Hugging Face Jobs.

## Authentication Issues

### Error: 401 Unauthorized

**Symptoms:**
```
401 Client Error: Unauthorized for url: https://huggingface.co/api/...
```

**Causes:**
- Token missing from job
- Token invalid or expired
- Token not passed correctly

**Solutions:**
1. Add token to secrets: `hf_jobs` MCP uses `"$HF_TOKEN"` (auto-replaced); `HfApi().run_uv_job()` MUST use `get_token()` from `huggingface_hub` (the literal string `"$HF_TOKEN"` will NOT work with the Python API)
2. Verify `hf_whoami()` works locally
3. Re-login: `hf auth login`
4. Check token hasn't expired

**Verification:**
```python
# In your script
import os
assert "HF_TOKEN" in os.environ, "HF_TOKEN missing!"
```

### Error: 403 Forbidden

**Symptoms:**
```
403 Client Error: Forbidden for url: https://huggingface.co/api/...
```

**Causes:**
- Token lacks required permissions
- No access to private repository
- Organization permissions insufficient

**Solutions:**
1. Ensure token has write permissions
2. Check token type at https://huggingface.co/settings/tokens
3. Verify access to target repository
4. Use organization token if needed

### Error: Token not found in environment

**Symptoms:**
```
KeyError: 'HF_TOKEN'
ValueError: HF_TOKEN not found
```

**Causes:**
- `secrets` not passed in job config
- Wrong key name (should be `HF_TOKEN`)
- Using `env` instead of `secrets`

**Solutions:**
1. Use `secrets` (not `env`) — with `hf_jobs` MCP: `"$HF_TOKEN"`; with `HfApi().run_uv_job()`: `get_token()`
2. Verify key name is exactly `HF_TOKEN`
3. Check job config syntax

## Job Execution Issues

### Error: Job Timeout

**Symptoms:**
- Job stops unexpectedly
- Status shows "TIMEOUT"
- Partial results only

**Causes:**
- Default 30min timeout exceeded
- Job takes longer than expected
- No timeout specified

**Solutions:**
1. Check logs for actual runtime
2. Increase timeout with buffer: `"timeout": "3h"`
3. Optimize code for faster execution
4. Process data in chunks
5. Add 20-30% buffer to estimated time

**MCP Tool Example:**
```python
hf_jobs("uv", {
    "script": "...",
    "timeout": "2h"  # Set appropriate timeout
})
```

**Python API Example:**
```python
from huggingface_hub import run_uv_job, inspect_job, fetch_job_logs

job = run_uv_job("script.py", timeout="4h")

# Check if job failed
job_info = inspect_job(job_id=job.id)
if job_info.status.stage == "ERROR":
    print(f"Job failed: {job_info.status.message}")
    # Check logs for details
    for log in fetch_job_logs(job_id=job.id):
        print(log)
```

### Error: Out of Memory (OOM)

**Symptoms:**
```
RuntimeError: CUDA out of memory
MemoryError: Unable to allocate array
```

**Causes:**
- Batch size too large
- Model too large for hardware
- Insufficient GPU memory

**Solutions:**
1. Reduce batch size
2. Process data in smaller chunks
3. Upgrade hardware: cpu → t4 → a10g → a100
4. Use smaller models or quantization
5. Enable gradient checkpointing (for training)

**Example:**
```python
# Reduce batch size
batch_size = 1

# Process in chunks
for chunk in chunks:
    process(chunk)
```

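The chunking idea above can be made concrete with a small batching helper. A generic sketch: `iter_batches` is illustrative, and the doubling step stands in for whatever per-batch work (model forward pass, tokenization, etc.) your job actually does.

```python
def iter_batches(items, batch_size):
    """Yield successive slices so only one batch is resident in memory/GPU at a time."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

processed = []
for batch in iter_batches(list(range(10)), batch_size=4):
    # placeholder for the real per-batch work
    processed.extend(x * 2 for x in batch)

print(len(processed))  # 10
```

If a batch size still OOMs, halve `batch_size` and retry rather than upgrading hardware immediately.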
### Error: Missing Dependencies

**Symptoms:**
```
ModuleNotFoundError: No module named 'package_name'
ImportError: cannot import name 'X'
```

**Causes:**
- Package not in dependencies
- Wrong package name
- Version mismatch

**Solutions:**
1. Add to PEP 723 header:
   ```python
   # /// script
   # dependencies = ["package-name>=1.0.0"]
   # ///
   ```
2. Check package name spelling
3. Specify version if needed
4. Check package availability

### Error: Script Not Found

**Symptoms:**
```
FileNotFoundError: script.py not found
```

**Causes:**
- Local file path used (not supported)
- URL incorrect
- Script not accessible

**Solutions:**
1. Use inline script (recommended)
2. Use publicly accessible URL
3. Upload script to Hub first
4. Check URL is correct

**Correct approaches:**
```python
# ✅ Inline code
hf_jobs("uv", {"script": "# /// script\n# dependencies = [...]\n# ///\n\n<code>"})

# ✅ From URL
hf_jobs("uv", {"script": "https://huggingface.co/user/repo/resolve/main/script.py"})
```

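For option 3 (uploading the script to the Hub first), the flow can be sketched as below. The repo id and filename are placeholders, and the actual upload (guarded behind `run=True`) requires `huggingface_hub` and a token with write access; the URL shape is the raw-file form Jobs can fetch from a dataset repo.

```python
def script_url(repo_id, filename):
    # raw URL shape that Jobs can fetch from a dataset repo
    return f"https://huggingface.co/datasets/{repo_id}/raw/main/{filename}"

def publish_script(repo_id, filename, run=False):
    """Upload a local script to a Hub dataset repo and return its fetchable URL."""
    if run:  # real upload path; needs huggingface_hub and HF_TOKEN with write access
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(repo_id, repo_type="dataset", exist_ok=True)
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=filename,
            repo_id=repo_id,
            repo_type="dataset",
        )
    return script_url(repo_id, filename)

print(publish_script("username/job-scripts", "script.py"))
```

The returned URL is what you then pass as `"script"` in the job config.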
## Hub Push Issues

### Error: Push Failed

**Symptoms:**
```
Error pushing to Hub
Upload failed
```

**Causes:**
- Network issues
- Token missing or invalid
- Repository access denied
- File too large

**Solutions:**
1. Check token: `assert "HF_TOKEN" in os.environ`
2. Verify repository exists or can be created
3. Check network connectivity in logs
4. Retry push operation
5. Split large files into chunks

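Point 4 (retrying the push) can be wrapped in a small exponential-backoff helper. A generic sketch: `push_with_retry` is a hypothetical helper, and `push_fn` stands in for the actual `push_to_hub`/`upload_file` call.

```python
import time

def push_with_retry(push_fn, attempts=3, base_delay=1.0):
    """Retry a flaky Hub push with exponential backoff; re-raise after the last attempt."""
    for attempt in range(attempts):
        try:
            return push_fn()
        except Exception as e:
            if attempt == attempts - 1:
                raise
            delay = base_delay * (2 ** attempt)
            print(f"Push failed ({e}); retrying in {delay:.2f}s...")
            time.sleep(delay)

# Demo with a stub that fails twice, then succeeds
calls = {"n": 0}
def flaky_push():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient network error")
    return "ok"

print(push_with_retry(flaky_push, base_delay=0.01))  # ok
```

Only retry on transient errors; a 401/403 will fail identically on every attempt and should surface immediately.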
### Error: Repository Not Found

**Symptoms:**
```
404 Client Error: Not Found
Repository not found
```

**Causes:**
- Repository doesn't exist
- Wrong repository name
- No access to private repo

**Solutions:**
1. Create repository first:
   ```python
   from huggingface_hub import HfApi
   api = HfApi()
   api.create_repo("username/repo-name", repo_type="dataset")
   ```
2. Check repository name format
3. Verify namespace exists
4. Check repository visibility

### Error: Results Not Saved

**Symptoms:**
- Job completes successfully
- No results visible on Hub
- Files not persisted

**Causes:**
- No persistence code in script
- Push code not executed
- Push failed silently

**Solutions:**
1. Add persistence code to script
2. Verify push executes successfully
3. Check logs for push errors
4. Add error handling around push

**Example:**
```python
try:
    dataset.push_to_hub("username/dataset")
    print("✅ Push successful")
except Exception as e:
    print(f"❌ Push failed: {e}")
    raise
```

## Hardware Issues

### Error: GPU Not Available

**Symptoms:**
```
CUDA not available
No GPU found
```

**Causes:**
- CPU flavor used instead of GPU
- GPU not requested
- CUDA not installed in image

**Solutions:**
1. Use GPU flavor: `"flavor": "a10g-large"`
2. Check image has CUDA support
3. Verify GPU availability in logs

### Error: Slow Performance

**Symptoms:**
- Job takes longer than expected
- Low GPU utilization
- CPU bottleneck

**Causes:**
- Wrong hardware selected
- Inefficient code
- Data loading bottleneck

**Solutions:**
1. Upgrade hardware
2. Optimize code
3. Use batch processing
4. Profile code to find bottlenecks

## General Issues

### Error: Job Status Unknown

**Symptoms:**
- Can't check job status
- Status API returns error

**Solutions:**
1. Use job URL: `https://huggingface.co/jobs/username/job-id`
2. Check logs: `hf_jobs("logs", {"job_id": "..."})`
3. Inspect job: `hf_jobs("inspect", {"job_id": "..."})`

### Error: Logs Not Available

**Symptoms:**
- No logs visible
- Logs delayed

**Causes:**
- Job just started (logs delayed 30-60s)
- Job failed before logging
- Logs not yet generated

**Solutions:**
1. Wait 30-60 seconds after job start
2. Check job status first
3. Use job URL for web interface

### Error: Cost Unexpectedly High

**Symptoms:**
- Job costs more than expected
- Longer runtime than estimated

**Causes:**
- Job ran longer than estimated
- Wrong hardware selected
- Inefficient code

**Solutions:**
1. Monitor job runtime
2. Set appropriate timeout
3. Optimize code
4. Choose right hardware
5. Check cost estimates before running

## Debugging Tips

### 1. Add Logging

```python
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Starting processing...")
logger.info(f"Processed {count} items")
```

### 2. Verify Environment

```python
import os
import sys

import torch

print(f"Python version: {sys.version}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"HF_TOKEN present: {'HF_TOKEN' in os.environ}")
```

### 3. Test Locally First

Run the script locally before submitting to catch errors early:
```bash
python script.py
# Or with uv
uv run script.py
```

### 4. Check Job Logs

**MCP Tool:**
```python
# View logs
hf_jobs("logs", {"job_id": "your-job-id"})
```

**CLI:**
```bash
hf jobs logs <job-id>
```

**Python API:**
```python
from huggingface_hub import fetch_job_logs
for log in fetch_job_logs(job_id="your-job-id"):
    print(log)
```

**Or use the job URL:** `https://huggingface.co/jobs/username/job-id`

### 5. Add Error Handling

```python
try:
    # Your code
    process_data()
except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()
    raise
```

### 6. Check Job Status Programmatically

```python
from huggingface_hub import inspect_job, fetch_job_logs

job_info = inspect_job(job_id="your-job-id")
print(f"Status: {job_info.status.stage}")
print(f"Message: {job_info.status.message}")

if job_info.status.stage == "ERROR":
    print("Job failed! Logs:")
    for log in fetch_job_logs(job_id="your-job-id"):
        print(log)
```

## Quick Reference

### Common Error Codes

| Code | Meaning | Solution |
|------|---------|----------|
| 401 | Unauthorized | Add token to secrets: MCP uses `"$HF_TOKEN"`, Python API uses `get_token()` |
| 403 | Forbidden | Check token permissions |
| 404 | Not Found | Verify repository exists |
| 500 | Server Error | Retry or contact support |

### Checklist Before Submitting

- [ ] Token configured: MCP uses `secrets={"HF_TOKEN": "$HF_TOKEN"}`, Python API uses `secrets={"HF_TOKEN": get_token()}`
- [ ] Script checks for token: `assert "HF_TOKEN" in os.environ`
- [ ] Timeout set appropriately
- [ ] Hardware selected correctly
- [ ] Dependencies listed in PEP 723 header
- [ ] Persistence code included
- [ ] Error handling added
- [ ] Logging added for debugging

## Getting Help

If issues persist:

1. **Check logs** - Most errors include detailed messages
2. **Review documentation** - See main SKILL.md
3. **Check Hub status** - https://status.huggingface.co
4. **Community forums** - https://discuss.huggingface.co
5. **GitHub issues** - For bugs in huggingface_hub

## Key Takeaways

1. **Always include token** - MCP: `secrets={"HF_TOKEN": "$HF_TOKEN"}`, Python API: `secrets={"HF_TOKEN": get_token()}`
2. **Set appropriate timeout** - Default 30min may be insufficient
3. **Verify persistence** - Results won't persist without code
4. **Check logs** - Most issues visible in job logs
5. **Test locally** - Catch errors before submitting
6. **Add error handling** - Better debugging information
7. **Monitor costs** - Set timeouts to avoid unexpected charges

@@ -0,0 +1,718 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "datasets",
#     "transformers",
#     "vllm>=0.6.5",
#     "huggingface-hub[hf_transfer]",
#     "torch",
#     "numpy",
#     "tqdm",
#     "scikit-learn",
# ]
# ///
"""
Generate high-quality synthetic data using Chain-of-Thought Self-Instruct methodology.

This script implements the CoT-Self-Instruct approach from the paper "CoT-Self-Instruct:
Building high-quality synthetic prompts for reasoning and non-reasoning tasks" (2025).

It supports two modes:
1. Reasoning tasks: Generates both questions and answers with Chain-of-Thought
2. Instruction tasks: Generates diverse prompts for general instruction following

Example usage:
    # Reasoning tasks with Answer-Consistency filtering
    uv run cot-self-instruct.py \\
        --seed-dataset davanstrien/s1k-reasoning \\
        --output-dataset username/synthetic-math \\
        --task-type reasoning \\
        --num-samples 5000 \\
        --filter-method answer-consistency

    # Instruction tasks with RIP filtering
    uv run cot-self-instruct.py \\
        --seed-dataset wildchat-filtered \\
        --output-dataset username/synthetic-prompts \\
        --task-type instruction \\
        --filter-method rip \\
        --reward-model Nexusflow/Athene-RM-8B

    # HF Jobs execution
    hf jobs uv run --flavor l4x4 \\
        --image vllm/vllm-openai \\
        -e HF_TOKEN=$(python3 -c "from huggingface_hub import get_token; print(get_token())") \\
        https://huggingface.co/datasets/uv-scripts/synthetic-data/raw/main/cot-self-instruct.py \\
        [args...]
"""

import argparse
import json
import logging
import os
import random
import re
import sys
from collections import Counter
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from datasets import Dataset, load_dataset
from huggingface_hub import DatasetCard, login
from sklearn.cluster import KMeans
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Enable HF Transfer for faster downloads
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Prompt templates from the paper
REASONING_PROMPT_TEMPLATE = """You are a reasoning question generator assistant. Your goal is to create a novel, and challenging reasoning question. You are provided the following seed questions:
Seed Question 1: {seed1}
Seed Question 2: {seed2}
Your task is to:
1. Write a brand-new, self-contained reasoning question that meets the following requirements:
(a) The question draws inspiration from the seed question without copying it verbatim, remaining novel and of comparable difficulty.
(b) The question's final answer should be a single, unambiguous scalar value (e.g., an integer, reduced fraction, exact radical), or another answer type that can be verified in one step (e.g., 'yes/no,' a choice from A to D).
2. Then reason step by step, solve the new question and format your output as follows:
[New Question Begin]{{your_generated_question}}[New Question End]
[Final Answer to New Question Begin]\\boxed{{your_final_answer}}[Final Answer to New Question End]"""

INSTRUCTION_PROMPT_TEMPLATE = """You are a prompt generator assistant. Your goal is to create diverse and creative synthetic prompts.
Please follow the steps below to create synthetic prompts.
Step 1: Carefully read #Prompt 1# and #Prompt 2#. Identify and list all the common elements between these two prompts. If no common elements are found, list the main elements from each prompt.
Step 2: Develop a comprehensive plan based on the #Common Elements List# or #Main Elements List# from Step 1. This plan will guide the generation of new synthetic prompts that are similar to the original prompts.
Step 3: Execute the plan step by step and provide one #Synthetic Prompt#.
Please reply strictly in the following format:
- Step 1 #Common Elements List# or #Main Elements List#:
- Step 2 #Plan#:
- Step 3 #Synthetic Prompt#:
#Prompt 1#:
{prompt1}
#Prompt 2#:
{prompt2}"""


def check_gpu_availability() -> int:
    """Check if CUDA is available and return the number of GPUs."""
    if not torch.cuda.is_available():
        logger.error("CUDA is not available. This script requires a GPU.")
        logger.error(
            "Please run on a machine with NVIDIA GPU or use HF Jobs with GPU flavor."
        )
        sys.exit(1)

    num_gpus = torch.cuda.device_count()
    for i in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(i)
        gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
        logger.info(f"GPU {i}: {gpu_name} with {gpu_memory:.1f} GB memory")

    return num_gpus


def parse_thinking_output(text: str) -> str:
    """Remove thinking tokens from model output."""
    # Remove <think>...</think> blocks
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    return text.strip()


def extract_reasoning_output(text: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract question and answer from reasoning task output."""
    text = parse_thinking_output(text)

    # Extract question
    question_match = re.search(r'\[New Question Begin\](.*?)\[New Question End\]', text, re.DOTALL)
    if not question_match:
        return None, None
    question = question_match.group(1).strip()

    # Extract answer
    answer_match = re.search(r'\[Final Answer to New Question Begin\]\\?boxed\{(.*?)\}\[Final Answer to New Question End\]', text, re.DOTALL)
    if not answer_match:
        # Try without \boxed
        answer_match = re.search(r'\[Final Answer to New Question Begin\](.*?)\[Final Answer to New Question End\]', text, re.DOTALL)

    if not answer_match:
        return question, None

    answer = answer_match.group(1).strip()
    return question, answer


def extract_instruction_output(text: str) -> Optional[str]:
    """Extract synthetic prompt from instruction task output."""
    text = parse_thinking_output(text)

    # Look for the synthetic prompt after "Step 3 #Synthetic Prompt#:"
    match = re.search(r'Step 3 #Synthetic Prompt#:\s*(.+)', text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return None


def categorize_prompts(prompts: List[str], num_categories: int = 8) -> Dict[int, List[int]]:
    """Categorize prompts using clustering for instruction tasks."""
    from transformers import AutoModel

    logger.info(f"Categorizing {len(prompts)} prompts into {num_categories} categories...")

    # Use a small model for embeddings
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

    # Get embeddings
    embeddings = []
    for prompt in tqdm(prompts, desc="Computing embeddings"):
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).numpy()
        embeddings.append(embedding[0])

    # Cluster
    kmeans = KMeans(n_clusters=num_categories, random_state=42)
    labels = kmeans.fit_predict(embeddings)

    # Group by category
    categories = {}
    for idx, label in enumerate(labels):
        if label not in categories:
            categories[label] = []
        categories[label].append(idx)

    return categories


def generate_synthetic_data(
    llm: LLM,
    seed_data: List[Dict],
    task_type: str,
    num_samples: int,
    categories: Optional[Dict[int, List[int]]] = None,
) -> List[Dict]:
    """Generate synthetic data using CoT-Self-Instruct."""
    synthetic_data = []

    # Set up progress bar
    pbar = tqdm(total=num_samples, desc="Generating synthetic data")

    while len(synthetic_data) < num_samples:
        # Sample seed data
        if task_type == "reasoning":
            # Random sampling for reasoning tasks
            seeds = random.sample(seed_data, min(2, len(seed_data)))
            prompt = REASONING_PROMPT_TEMPLATE.format(
                seed1=seeds[0].get("question", seeds[0].get("prompt", "")),
                seed2=seeds[1].get("question", seeds[1].get("prompt", "")) if len(seeds) > 1 else seeds[0].get("question", seeds[0].get("prompt", ""))
            )
        else:
            # Category-aware sampling for instruction tasks
            if categories:
                # Pick a random category
                category = random.choice(list(categories.keys()))
                category_indices = categories[category]
                indices = random.sample(category_indices, min(2, len(category_indices)))
                seeds = [seed_data[i] for i in indices]
            else:
                seeds = random.sample(seed_data, min(2, len(seed_data)))

            prompt = INSTRUCTION_PROMPT_TEMPLATE.format(
                prompt1=seeds[0].get("prompt", seeds[0].get("question", "")),
                prompt2=seeds[1].get("prompt", seeds[1].get("question", "")) if len(seeds) > 1 else seeds[0].get("prompt", seeds[0].get("question", ""))
            )

        # Generate
        sampling_params = SamplingParams(
            temperature=0.7 if task_type == "reasoning" else 0.8,
            top_p=0.95 if task_type == "reasoning" else 0.9,
            max_tokens=2048,
        )

        outputs = llm.generate([prompt], sampling_params)
        output_text = outputs[0].outputs[0].text

        # Parse output
        if task_type == "reasoning":
            question, answer = extract_reasoning_output(output_text)
            if question and answer:
                synthetic_data.append({
                    "question": question,
                    "answer": answer,
                    "seed_indices": [seed_data.index(s) for s in seeds],
                })
                pbar.update(1)
        else:
            synthetic_prompt = extract_instruction_output(output_text)
            if synthetic_prompt:
                synthetic_data.append({
                    "prompt": synthetic_prompt,
                    "seed_indices": [seed_data.index(s) for s in seeds],
                })
                pbar.update(1)

    pbar.close()
    return synthetic_data


def answer_consistency_filter(
    llm: LLM,
    synthetic_data: List[Dict],
    k_responses: int = 16,
    threshold: float = 0.5,
) -> List[Dict]:
    """Filter reasoning tasks using Answer-Consistency."""
    logger.info(f"Applying Answer-Consistency filter with K={k_responses}")

    filtered_data = []

    for item in tqdm(synthetic_data, desc="Answer-Consistency filtering"):
        question = item["question"]
        original_answer = item["answer"]

        # Generate K responses
        prompts = [question] * k_responses
        sampling_params = SamplingParams(
            temperature=0.6,
            top_p=0.95,
            max_tokens=1024,
        )

        outputs = llm.generate(prompts, sampling_params)

        # Extract answers
        answers = []
        for output in outputs:
            text = output.outputs[0].text
            # Try to extract boxed answer
            match = re.search(r'\\boxed\{(.*?)\}', text)
            if match:
                answers.append(match.group(1).strip())

        if not answers:
            continue

        # Get majority answer
        answer_counts = Counter(answers)
        if answer_counts:
            majority_answer, count = answer_counts.most_common(1)[0]

            # Check if majority answer matches original and meets threshold
            if (majority_answer == original_answer and
                count / len(answers) >= threshold):
                item["consistency_score"] = count / len(answers)
                filtered_data.append(item)

    logger.info(f"Answer-Consistency: kept {len(filtered_data)}/{len(synthetic_data)} examples")
    return filtered_data


def rip_filter(
    llm: LLM,
    synthetic_data: List[Dict],
    reward_model_id: str,
    k_responses: int = 32,
    threshold: float = 0.5,
) -> List[Dict]:
    """Filter using Rejecting Instruction Preferences (RIP)."""
    logger.info(f"Applying RIP filter with K={k_responses} and reward model {reward_model_id}")

    # Note: In a full implementation, you would load and use the actual reward model
    # For this example, we'll use a placeholder scoring mechanism
    logger.warning("RIP filtering requires a reward model implementation - using placeholder")

    filtered_data = []

    for item in tqdm(synthetic_data, desc="RIP filtering"):
        prompt = item.get("prompt", item.get("question", ""))

        # Generate K responses
        prompts = [prompt] * k_responses
        sampling_params = SamplingParams(
            temperature=1.0,
            top_p=1.0,
            max_tokens=1024,
        )

        outputs = llm.generate(prompts, sampling_params)

        # In real implementation: score each response with reward model
        # For now, use length as a proxy (longer responses often score higher)
        scores = [len(output.outputs[0].text) for output in outputs]

        # Use minimum score as quality indicator
        min_score = min(scores) if scores else 0
        normalized_score = min_score / 1000  # Normalize to 0-1 range

        if normalized_score >= threshold:
            item["rip_score"] = normalized_score
            filtered_data.append(item)

    logger.info(f"RIP filter: kept {len(filtered_data)}/{len(synthetic_data)} examples")
    return filtered_data


def create_dataset_card(
    task_type: str,
    source_dataset: str,
    generation_model: str,
    filter_method: str,
    num_generated: int,
    num_filtered: int,
    generation_time: str,
    additional_info: Optional[Dict] = None,
) -> str:
    """Create a comprehensive dataset card."""
    filter_info = ""
    if filter_method == "answer-consistency":
        filter_info = """
### Answer-Consistency Filtering

This dataset was filtered using Answer-Consistency:
- Generated K responses for each synthetic question
- Kept only examples where majority answer matched the generated answer
- Ensures high-quality, correctly solved problems"""
    elif filter_method == "rip":
        filter_info = """
### RIP (Rejecting Instruction Preferences) Filtering

This dataset was filtered using RIP:
- Generated K responses for each synthetic prompt
- Scored responses using a reward model
- Kept only prompts with high minimum scores"""

    # Guard against division by zero when nothing was generated
    acceptance_rate = (num_filtered / num_generated * 100) if num_generated else 0.0

    return f"""---
tags:
- synthetic-data
- cot-self-instruct
- {task_type}
- uv-script
---

# CoT-Self-Instruct Synthetic Data

This dataset contains synthetic {task_type} data generated using the Chain-of-Thought Self-Instruct methodology.

## Generation Details

- **Source Dataset**: [{source_dataset}](https://huggingface.co/datasets/{source_dataset})
- **Generation Model**: [{generation_model}](https://huggingface.co/{generation_model})
- **Task Type**: {task_type}
- **Filter Method**: {filter_method}
- **Generated Examples**: {num_generated:,}
- **After Filtering**: {num_filtered:,} ({acceptance_rate:.1f}% acceptance rate)
- **Generation Date**: {generation_time}
{filter_info}

## Methodology

Generated using CoT-Self-Instruct, which:
1. Uses Chain-of-Thought reasoning to analyze seed examples
2. Generates new synthetic examples of similar quality and complexity
3. Applies quality filtering to ensure high-quality outputs

Based on the paper: "CoT-Self-Instruct: Building high-quality synthetic prompts for reasoning and non-reasoning tasks" (2025)

## Generation Script

Generated using the CoT-Self-Instruct script from [uv-scripts/synthetic-data](https://huggingface.co/datasets/uv-scripts/synthetic-data).

To reproduce:
```bash
uv run https://huggingface.co/datasets/uv-scripts/synthetic-data/raw/main/cot-self-instruct.py \\
    --seed-dataset {source_dataset} \\
    --output-dataset <your-dataset> \\
    --task-type {task_type} \\
    --generation-model {generation_model} \\
    --filter-method {filter_method}
```
"""


def main():
    parser = argparse.ArgumentParser(
        description="Generate synthetic data using CoT-Self-Instruct",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Dataset arguments
    parser.add_argument(
        "--seed-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset ID containing seed examples",
    )
    parser.add_argument(
        "--output-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset ID for output",
    )

    # Task configuration
    parser.add_argument(
        "--task-type",
        type=str,
        choices=["reasoning", "instruction", "auto"],
        default="auto",
        help="Type of task (reasoning generates Q&A, instruction generates prompts)",
    )
    parser.add_argument(
        "--task-column",
        type=str,
        default=None,
        help="Column name containing tasks (auto-detected if not specified)",
    )

    # Model configuration
    parser.add_argument(
        "--generation-model",
        type=str,
        default="Qwen/Qwen3-30B-A3B-Thinking-2507",
        help="Model for synthetic data generation",
    )
    parser.add_argument(
        "--filter-model",
        type=str,
        default=None,
        help="Model for filtering (defaults to generation model)",
    )
    parser.add_argument(
        "--reward-model",
        type=str,
        default="Nexusflow/Athene-RM-8B",
        help="Reward model for RIP filtering",
    )

    # Generation parameters
    parser.add_argument(
        "--num-samples",
        type=int,
        default=5000,
        help="Number of synthetic examples to generate",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for generation",
    )

    # Filtering parameters
    parser.add_argument(
        "--filter-method",
        type=str,
        choices=["answer-consistency", "rip", "both", "none"],
        default="answer-consistency",
        help="Quality filtering method",
    )
    parser.add_argument(
        "--k-responses",
        type=int,
        default=16,
        help="Number of responses for filtering",
    )
    parser.add_argument(
        "--quality-threshold",
        type=float,
        default=0.5,
        help="Minimum quality threshold for filtering",
    )

    # GPU configuration
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=None,
        help="Number of GPUs for tensor parallelism (auto-detected if not set)",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="GPU memory utilization",
    )

    # Other arguments
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="HuggingFace API token",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )

    args = parser.parse_args()

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Check GPU
    num_gpus = check_gpu_availability()
    tensor_parallel_size = args.tensor_parallel_size or num_gpus

    # Authentication
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    # Load seed dataset
    logger.info(f"Loading seed dataset: {args.seed_dataset}")
    seed_dataset = load_dataset(args.seed_dataset, split="train")

    # Auto-detect task type and column if needed
    if args.task_type == "auto":
        columns = seed_dataset.column_names
        if "question" in columns and "answer" in columns:
            args.task_type = "reasoning"
            logger.info("Auto-detected task type: reasoning")
        else:
            args.task_type = "instruction"
            logger.info("Auto-detected task type: instruction")

    if not args.task_column:
        if args.task_type == "reasoning":
            args.task_column = "question"
        else:
            # Try to find prompt column
            for col in ["prompt", "instruction", "text", "input"]:
                if col in seed_dataset.column_names:
                    args.task_column = col
                    break

    logger.info(f"Using task column: {args.task_column}")

    # Convert to list of dicts
    seed_data = seed_dataset.to_list()

    # Categorize prompts for instruction tasks
    categories = None
    if args.task_type == "instruction" and len(seed_data) > 100:
        prompts = [item.get(args.task_column, "") for item in seed_data]
        categories = categorize_prompts(prompts)

    # Initialize generation model
    logger.info(f"Loading generation model: {args.generation_model}")
    generation_llm = LLM(
        model=args.generation_model,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    # Generate synthetic data
    start_time = datetime.now()
    synthetic_data = generate_synthetic_data(
        generation_llm,
        seed_data,
        args.task_type,
        args.num_samples,
        categories,
    )

    # Apply filtering
    filter_llm = generation_llm
    if args.filter_model and args.filter_model != args.generation_model:
        logger.info(f"Loading filter model: {args.filter_model}")
        # Clean up generation model
        del generation_llm
        torch.cuda.empty_cache()

        filter_llm = LLM(
            model=args.filter_model,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
        )

    filtered_data = synthetic_data
    if args.filter_method != "none":
        if args.filter_method == "answer-consistency" and args.task_type == "reasoning":
            filtered_data = answer_consistency_filter(
                filter_llm,
                synthetic_data,
                args.k_responses,
                args.quality_threshold,
            )
        elif args.filter_method == "rip":
            filtered_data = rip_filter(
                filter_llm,
                synthetic_data,
                args.reward_model,
                args.k_responses,
                args.quality_threshold,
            )
        elif args.filter_method == "both":
            if args.task_type == "reasoning":
                filtered_data = answer_consistency_filter(
                    filter_llm,
                    synthetic_data,
                    args.k_responses,
                    args.quality_threshold,
                )
            filtered_data = rip_filter(
                filter_llm,
                filtered_data,
                args.reward_model,
                args.k_responses,
                args.quality_threshold,
            )

    # Create HuggingFace dataset
    logger.info(f"Creating dataset with {len(filtered_data)} examples")
    dataset = Dataset.from_list(filtered_data)

    # Create dataset card
    generation_time = start_time.strftime("%Y-%m-%d %H:%M:%S UTC")
    dataset_card = create_dataset_card(
        args.task_type,
        args.seed_dataset,
        args.generation_model,
        args.filter_method,
        len(synthetic_data),
        len(filtered_data),
        generation_time,
    )

    # Push to hub
    logger.info(f"Pushing dataset to: {args.output_dataset}")
    card = DatasetCard(dataset_card)
    dataset.push_to_hub(args.output_dataset)
    # Push card separately
    card.push_to_hub(args.output_dataset)

    logger.info("Done! Dataset available at: https://huggingface.co/datasets/" + args.output_dataset)

    # Print example HF Jobs command if running locally
    if len(sys.argv) > 1:
        print("\nTo run on HF Jobs:")
        print(f"""hf jobs uv run --flavor l4x4 \\
    --image vllm/vllm-openai \\
    -e HF_TOKEN=$(python3 -c "from huggingface_hub import get_token; print(get_token())") \\
    https://huggingface.co/datasets/uv-scripts/synthetic-data/raw/main/cot-self-instruct.py \\
    --seed-dataset {args.seed_dataset} \\
    --output-dataset {args.output_dataset} \\
    --task-type {args.task_type} \\
    --generation-model {args.generation_model} \\
    --filter-method {args.filter_method} \\
    --num-samples {args.num_samples}""")


if __name__ == "__main__":
    main()

@@ -0,0 +1,546 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "polars>=1.31.0",
#     "huggingface-hub",
#     "datasets",
#     "ascii-graph",
# ]
# ///
"""
Analyze educational quality trends across CommonCrawl dumps using Polars streaming.

Answers: "Is the web getting more educational over time?"

Demonstrates Polars HF Hub integration - process 50M+ docs without downloading 300GB+.

Example usage:
    # Analyze English PDFs (default)
    uv run finepdfs-stats.py

    # Analyze all 70+ languages
    uv run finepdfs-stats.py --all-languages

    # Quick test
    uv run finepdfs-stats.py --limit 10000 --show-plan

    # Save results to HF Hub
    uv run finepdfs-stats.py --output-repo username/finepdfs-temporal-stats

    # Run on HF Jobs
    hf jobs uv run \\
        -s HF_TOKEN \\
        -e HF_XET_HIGH_PERFORMANCE=1 \\
        https://huggingface.co/datasets/uv-scripts/dataset-stats/raw/main/finepdfs-stats.py \\
        -- --output-repo username/stats
"""

import argparse
import logging
import os
import sys
import time
from pathlib import Path

import polars as pl
from ascii_graph import Pyasciigraph
from datasets import Dataset
from huggingface_hub import HfApi, create_repo, list_repo_tree, login

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Common language+script codes for finepdfs-edu
COMMON_LANGUAGES = {
    "eng_Latn": "English (Latin script)",
    "fra_Latn": "French (Latin script)",
    "deu_Latn": "German (Latin script)",
    "spa_Latn": "Spanish (Latin script)",
    "por_Latn": "Portuguese (Latin script)",
    "ita_Latn": "Italian (Latin script)",
    "nld_Latn": "Dutch (Latin script)",
    "pol_Latn": "Polish (Latin script)",
    "rus_Cyrl": "Russian (Cyrillic script)",
    "zho_Hans": "Chinese (Simplified)",
    "zho_Hant": "Chinese (Traditional)",
    "jpn_Jpan": "Japanese",
    "kor_Hang": "Korean",
    "ara_Arab": "Arabic",
    "hin_Deva": "Hindi (Devanagari)",
}


def list_available_languages(dataset_id: str) -> list[str]:
    """List available language subsets in the dataset."""
    try:
        tree = list_repo_tree(dataset_id, path_in_repo="data", repo_type="dataset")
        languages = [
            item.path.replace("data/", "")
            for item in tree
            if item.path.startswith("data/")
            and "/" not in item.path.replace("data/", "")
        ]
        return sorted(languages)
    except Exception as e:
        logger.warning(f"Could not list languages: {e}")
        return list(COMMON_LANGUAGES.keys())


def compute_temporal_stats(df: pl.LazyFrame, output_path: Path) -> pl.DataFrame:
    """Single scan: compute stats grouped by dump for temporal analysis."""
    query = df.group_by("dump").agg(
        pl.len().alias("doc_count"),
        pl.col("token_count").sum().alias("total_tokens"),
        pl.col("fw_edu_scores").list.mean().mean().alias("avg_edu_score"),
        (pl.col("fw_edu_scores").list.mean() >= 3).sum().alias("high_edu_count"),
    )
    query.sink_parquet(output_path, engine="streaming")
    return pl.read_parquet(output_path)


def compute_global_stats(temporal: pl.DataFrame) -> pl.DataFrame:
    """Compute global stats from temporal breakdown."""
    total = temporal["doc_count"].sum()
    return pl.DataFrame(
        {
            "total_docs": [total],
            "total_tokens": [temporal["total_tokens"].sum()],
            "avg_edu_score": [
                (temporal["avg_edu_score"] * temporal["doc_count"]).sum() / total
            ],
            "high_edu_rate": [temporal["high_edu_count"].sum() / total],
            "num_dumps": [len(temporal)],
        }
    )


def format_temporal_stats(temporal: pl.DataFrame) -> pl.DataFrame:
    """Format temporal stats with high_edu_rate, sorted chronologically."""
    return (
        temporal.with_columns(
            (pl.col("high_edu_count") / pl.col("doc_count")).alias("high_edu_rate")
        )
        .select(["dump", "doc_count", "avg_edu_score", "high_edu_rate"])
        .sort(
            "dump"
        )  # Chronological order (CC-MAIN-2017-xx comes before CC-MAIN-2024-xx)
    )


def create_ascii_charts(temporal_stats: pl.DataFrame) -> str:
    """Create ASCII bar charts showing temporal trends."""
    # Extract year from dump name (CC-MAIN-2024-42 -> 2024)
    # Group by year and average the values for cleaner display
    yearly = (
        temporal_stats.with_columns(
            pl.col("dump").str.extract(r"CC-MAIN-(\d{4})", 1).alias("year")
        )
        .group_by("year")
        .agg(
            pl.col("doc_count").sum(),
            pl.col("avg_edu_score").mean(),
            pl.col("high_edu_rate").mean(),
        )
        .sort("year")
    )

    lines = []

    # High edu rate chart (more dramatic differences)
    data_rate = [
        (row["year"], row["high_edu_rate"] * 100)
        for row in yearly.iter_rows(named=True)
    ]
    graph = Pyasciigraph(line_length=60, float_format="{0:.1f}%")
    lines.extend(graph.graph("High Educational Content (edu >= 3)", data_rate))

    lines.append("")

    # Avg edu score chart
    data_score = [
        (row["year"], row["avg_edu_score"]) for row in yearly.iter_rows(named=True)
    ]
    graph2 = Pyasciigraph(line_length=60, float_format="{0:.2f}")
    lines.extend(graph2.graph("Average Educational Score", data_score))

    return "\n".join(lines)


def create_readme(
|
||||
args,
|
||||
global_stats: pl.DataFrame,
|
||||
temporal_stats: pl.DataFrame,
|
||||
scan_time: float,
|
||||
ascii_charts: str,
|
||||
) -> str:
|
||||
"""Create README content for the stats dataset."""
|
||||
stats = global_stats.to_dicts()[0]
|
||||
total_docs = stats.get("total_docs", 0)
|
||||
docs_per_sec = total_docs / scan_time if scan_time > 0 else 0
|
||||
|
||||
# Get first and last year averages for trend (more representative than single dumps)
|
||||
yearly = (
|
||||
temporal_stats.with_columns(
|
||||
pl.col("dump").str.extract(r"CC-MAIN-(\d{4})", 1).alias("year")
|
||||
)
|
||||
.group_by("year")
|
||||
.agg(
|
||||
pl.col("doc_count").sum(),
|
||||
pl.col("avg_edu_score").mean(),
|
||||
pl.col("high_edu_rate").mean(),
|
||||
)
|
||||
.sort("year")
|
||||
)
|
||||
first_year = yearly.head(1).to_dicts()[0]
|
||||
last_year = yearly.tail(1).to_dicts()[0]
|
||||
|
||||
scope = (
|
||||
"all languages"
|
||||
if args.all_languages
|
||||
else COMMON_LANGUAGES.get(args.lang, args.lang)
|
||||
)
|
||||
|
||||
return f"""---
|
||||
tags:
|
||||
- uv-script
|
||||
- statistics
|
||||
- polars
|
||||
- finepdfs-edu
|
||||
- temporal-analysis
|
||||
license: odc-by
|
||||
configs:
|
||||
- config_name: global_stats
|
||||
data_files: global_stats/train-*.parquet
|
||||
- config_name: temporal_stats
|
||||
data_files: temporal_stats/train-*.parquet
|
||||
default_viewer_config: temporal_stats
|
||||
---
# Is the Web Getting More Educational?
Temporal analysis of educational quality in **{scope}** across {stats.get("num_dumps", 0)} CommonCrawl dumps.
## Trend
```
{ascii_charts}
```
## Key Finding
| Year | Avg Edu Score | High Edu Rate |
|------|---------------|---------------|
| {first_year["year"]} | {first_year["avg_edu_score"]:.2f} | {first_year["high_edu_rate"] * 100:.1f}% |
| {last_year["year"]} | {last_year["avg_edu_score"]:.2f} | {last_year["high_edu_rate"] * 100:.1f}% |
## Performance
- **{total_docs:,} documents** processed in **{scan_time:.0f} seconds**
- **{docs_per_sec:,.0f} docs/sec** using Polars streaming
- Single scan, no full dataset download required
## Summary
| Metric | Value |
|--------|-------|
| Scope | {scope} |
| Total Documents | {total_docs:,} |
| Total Tokens | {stats.get("total_tokens", 0):,} |
| Avg Edu Score | {stats.get("avg_edu_score", 0):.3f} |
| High Edu Rate | {stats.get("high_edu_rate", 0) * 100:.1f}% |
| CommonCrawl Dumps | {stats.get("num_dumps", 0)} |
## Files
- `global_stats` - Overall summary
- `temporal_stats` - Per-dump breakdown (sorted chronologically)
## Reproduce
```bash
uv run https://huggingface.co/datasets/uv-scripts/dataset-stats/raw/main/finepdfs-stats.py \\
{"--all-languages" if args.all_languages else f"--lang {args.lang}"} --output-repo your-username/stats
```
## Source
- **Dataset**: [{args.source_dataset}](https://huggingface.co/datasets/{args.source_dataset})
- **Script**: [uv-scripts/dataset-stats](https://huggingface.co/datasets/uv-scripts/dataset-stats)
"""
def main():
parser = argparse.ArgumentParser(
description="Analyze educational quality trends across CommonCrawl dumps",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--source-dataset",
type=str,
default="HuggingFaceFW/finepdfs-edu",
help="Source dataset (default: HuggingFaceFW/finepdfs-edu)",
)
parser.add_argument(
"--lang",
type=str,
default="eng_Latn",
help="Language+script code (default: eng_Latn)",
)
parser.add_argument(
"--all-languages",
action="store_true",
help="Analyze all languages (70+) instead of single language",
)
parser.add_argument(
"--show-plan",
action="store_true",
help="Show Polars query plan (demonstrates optimization)",
)
parser.add_argument(
"--list-languages",
action="store_true",
help="List available languages and exit",
)
parser.add_argument(
"--limit",
type=int,
help="Limit to first N rows (for testing)",
)
parser.add_argument(
|
||||
"--output-repo",
|
||||
type=str,
|
||||
help="HuggingFace dataset repository to upload results",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./stats_output",
|
||||
help="Local directory for output files",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-token",
|
||||
type=str,
|
||||
help="HuggingFace API token (or set HF_TOKEN env var)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Make the output dataset private",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check for high-performance mode
|
||||
if os.environ.get("HF_XET_HIGH_PERFORMANCE"):
|
||||
logger.info("High-performance mode enabled (HF_XET_HIGH_PERFORMANCE=1)")
|
||||
|
||||
# List languages mode
|
||||
if args.list_languages:
|
||||
print(f"Available language+script codes for {args.source_dataset}:\n")
|
||||
print("Common languages:")
|
||||
for code, name in COMMON_LANGUAGES.items():
|
||||
print(f" {code:12} - {name}")
|
||||
print("\nFetching full list from HF Hub...")
|
||||
all_langs = list_available_languages(args.source_dataset)
|
||||
print(f"\nAll available ({len(all_langs)} total):")
|
||||
for lang in all_langs[:30]: # Show first 30
|
||||
name = COMMON_LANGUAGES.get(lang, "")
|
||||
print(f" {lang:12} {name}")
|
||||
if len(all_langs) > 30:
|
||||
print(f" ... and {len(all_langs) - 30} more")
|
||||
sys.exit(0)
|
||||
|
||||
# Build the parquet path
|
||||
if args.all_languages:
|
||||
source_path = f"hf://datasets/{args.source_dataset}/data/*/train/*.parquet"
|
||||
scope_desc = "all languages"
|
||||
else:
|
||||
source_path = (
|
||||
f"hf://datasets/{args.source_dataset}/data/{args.lang}/train/*.parquet"
|
||||
)
|
||||
scope_desc = f"{args.lang} ({COMMON_LANGUAGES.get(args.lang, 'unknown')})"
|
||||
|
||||
logger.info(f"Scanning: {source_path}")
|
||||
logger.info(f"Scope: {scope_desc}")
|
||||
|
||||
# Create lazy frame - this doesn't load any data yet!
|
||||
logger.info("Creating lazy query plan...")
|
||||
df = pl.scan_parquet(source_path)
|
||||
|
||||
# Apply limit if specified
|
||||
if args.limit:
|
||||
logger.info(f"Limiting to first {args.limit:,} rows")
|
||||
df = df.head(args.limit)
|
||||
|
||||
# Show query plan if requested
|
||||
if args.show_plan:
|
||||
# Build a sample query to show the plan
|
||||
sample_query = df.select(
|
||||
pl.len(),
|
||||
pl.col("token_count").sum(),
|
||||
pl.col("language").n_unique(),
|
||||
)
|
||||
print("\nQuery Plan (showing Polars optimization):")
|
||||
print("=" * 60)
|
||||
print(sample_query.explain())
|
||||
print("=" * 60)
|
||||
print("\nNote: Polars uses projection pushdown - only reads columns needed!")
|
||||
print("The 'text' column is never loaded, making this very fast.\n")
|
||||
|
||||
# Create output directory
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Single scan: compute temporal stats
|
||||
logger.info("Computing temporal stats (single scan)...")
|
||||
start = time.perf_counter()
|
||||
temporal_path = output_dir / "temporal_stats.parquet"
|
||||
temporal_raw = compute_temporal_stats(df, temporal_path)
|
||||
scan_time = time.perf_counter() - start
|
||||
logger.info(f"Scan complete in {scan_time:.2f}s - {len(temporal_raw)} dumps")
|
||||
|
||||
# Compute stats
|
||||
global_stats = compute_global_stats(temporal_raw)
|
||||
temporal_stats = format_temporal_stats(temporal_raw)
|
||||
|
||||
# Save
|
||||
global_stats.write_parquet(output_dir / "global_stats.parquet")
|
||||
temporal_stats.write_parquet(output_dir / "temporal_stats.parquet")
|
||||
|
||||
# Print results
|
||||
total_docs = global_stats["total_docs"][0]
|
||||
docs_per_sec = total_docs / scan_time if scan_time > 0 else 0
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("IS THE WEB GETTING MORE EDUCATIONAL?")
|
||||
print("=" * 70)
|
||||
|
||||
print(f"\nScope: {scope_desc}")
|
||||
print(f"Dataset: {args.source_dataset}")
|
||||
|
||||
print("\n" + "-" * 70)
|
||||
print("GLOBAL STATS")
|
||||
print("-" * 70)
|
||||
print(global_stats)
|
||||
|
||||
print("\n" + "-" * 70)
|
||||
print(f"TEMPORAL TREND ({len(temporal_stats)} CommonCrawl dumps)")
|
||||
print("-" * 70)
|
||||
# Show first 5 and last 5
|
||||
if len(temporal_stats) > 10:
|
||||
print("Earliest dumps:")
|
||||
print(temporal_stats.head(5))
|
||||
print("\n...")
|
||||
print("\nLatest dumps:")
|
||||
print(temporal_stats.tail(5))
|
||||
else:
|
||||
print(temporal_stats)
|
||||
|
||||
# Create ASCII charts
|
||||
ascii_charts = create_ascii_charts(temporal_stats)
|
||||
print("\n" + "-" * 70)
|
||||
print("TREND VISUALIZATION")
|
||||
print("-" * 70)
|
||||
print(ascii_charts)
|
||||
|
||||
print("\n" + "-" * 70)
|
||||
print("PERFORMANCE")
|
||||
print("-" * 70)
|
||||
print(f"Scan time: {scan_time:.2f}s")
|
||||
print(f"Documents: {total_docs:,}")
|
||||
print(f"Throughput: {docs_per_sec:,.0f} docs/sec")
|
||||
|
||||
logger.info(f"Results saved to: {output_dir}")
|
||||
|
||||
# Upload to HF Hub if requested
|
||||
if args.output_repo:
|
||||
hf_token = args.hf_token or os.environ.get("HF_TOKEN")
|
||||
if hf_token:
|
||||
login(token=hf_token)
|
||||
|
||||
api = HfApi(token=hf_token)
|
||||
|
||||
logger.info(f"Creating/updating dataset repository: {args.output_repo}")
|
||||
create_repo(
|
||||
args.output_repo,
|
||||
repo_type="dataset",
|
||||
private=args.private,
|
||||
token=hf_token,
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
# Upload each as a dataset config
|
||||
configs = [
|
||||
("global_stats", global_stats),
|
||||
("temporal_stats", temporal_stats),
|
||||
]
|
||||
|
||||
for config_name, stats_df in configs:
|
||||
logger.info(f"Uploading {config_name}...")
|
||||
ds = Dataset.from_polars(stats_df)
|
||||
ds.push_to_hub(
|
||||
args.output_repo,
|
||||
config_name=config_name,
|
||||
token=hf_token,
|
||||
private=args.private,
|
||||
)
|
||||
time.sleep(1) # Avoid 409 conflicts
|
||||
|
||||
# Upload README
|
||||
readme_content = create_readme(
|
||||
args, global_stats, temporal_stats, scan_time, ascii_charts
|
||||
)
|
||||
api.upload_file(
|
||||
path_or_fileobj=readme_content.encode(),
|
||||
path_in_repo="README.md",
|
||||
repo_id=args.output_repo,
|
||||
repo_type="dataset",
|
||||
token=hf_token,
|
||||
)
|
||||
|
||||
dataset_url = f"https://huggingface.co/datasets/{args.output_repo}"
|
||||
logger.info(f"Dataset uploaded: {dataset_url}")
|
||||
print(f"\nResults uploaded to: {dataset_url}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) == 1:
|
||||
print("Is the Web Getting More Educational?")
|
||||
print("=" * 40)
|
||||
print("\nAnalyze educational quality trends across CommonCrawl dumps")
|
||||
print("using Polars streaming - no download needed!\n")
|
||||
print("Example commands:\n")
|
||||
print("# Quick test:")
|
||||
print("uv run finepdfs-stats.py --limit 10000\n")
|
||||
print("# Analyze English PDFs:")
|
||||
print("uv run finepdfs-stats.py\n")
|
||||
print("# Analyze ALL 70+ languages:")
|
||||
print("uv run finepdfs-stats.py --all-languages\n")
|
||||
print("# Show query plan (see Polars optimization):")
|
||||
print("uv run finepdfs-stats.py --show-plan --limit 1000\n")
|
||||
print("# Save results to HF Hub:")
|
||||
print("uv run finepdfs-stats.py --output-repo username/temporal-stats\n")
|
||||
print("# Run on HF Jobs:")
|
||||
print("hf jobs uv run \\")
|
||||
print(" -s HF_TOKEN \\")
|
||||
print(" -e HF_XET_HIGH_PERFORMANCE=1 \\")
|
||||
print(
|
||||
" https://huggingface.co/datasets/uv-scripts/dataset-stats/raw/main/finepdfs-stats.py \\"
|
||||
)
|
||||
print(" -- --output-repo username/stats")
|
||||
sys.exit(0)
|
||||
|
||||
main()
|
||||
@@ -0,0 +1,587 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "datasets",
|
||||
# "flashinfer-python",
|
||||
# "huggingface-hub[hf_transfer]",
# "hf-xet>=1.1.7",
# "torch",
|
||||
# "transformers",
|
||||
# "vllm>=0.8.5",
|
||||
# ]
|
||||
#
|
||||
# ///
|
||||
"""
|
||||
Generate responses for prompts in a dataset using vLLM for efficient GPU inference.
|
||||
|
||||
This script loads a dataset from Hugging Face Hub containing chat-formatted messages,
|
||||
applies the model's chat template, generates responses using vLLM, and saves the
|
||||
results back to the Hub with a comprehensive dataset card.
|
||||
|
||||
Example usage:
|
||||
# Local execution with auto GPU detection
|
||||
uv run generate-responses.py \\
|
||||
username/input-dataset \\
|
||||
username/output-dataset \\
|
||||
--messages-column messages
|
||||
|
||||
# With custom model and sampling parameters
|
||||
uv run generate-responses.py \\
|
||||
username/input-dataset \\
|
||||
username/output-dataset \\
|
||||
--model-id meta-llama/Llama-3.1-8B-Instruct \\
|
||||
--temperature 0.9 \\
|
||||
--top-p 0.95 \\
|
||||
--max-tokens 2048
|
||||
|
||||
# HF Jobs execution (see script output for full command)
|
||||
hf jobs uv run --flavor a100x4 ...
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import DatasetCard, get_token, login
|
||||
from torch import cuda
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Enable HF Transfer for faster downloads
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_gpu_availability() -> int:
|
||||
"""Check if CUDA is available and return the number of GPUs."""
|
||||
if not cuda.is_available():
|
||||
logger.error("CUDA is not available. This script requires a GPU.")
|
||||
logger.error(
|
||||
"Please run on a machine with NVIDIA GPU or use HF Jobs with GPU flavor."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
num_gpus = cuda.device_count()
|
||||
for i in range(num_gpus):
|
||||
gpu_name = cuda.get_device_name(i)
|
||||
gpu_memory = cuda.get_device_properties(i).total_memory / 1024**3
|
||||
logger.info(f"GPU {i}: {gpu_name} with {gpu_memory:.1f} GB memory")
|
||||
|
||||
return num_gpus
|
||||
|
||||
|
||||
def create_dataset_card(
|
||||
source_dataset: str,
|
||||
model_id: str,
|
||||
messages_column: str,
|
||||
prompt_column: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
tensor_parallel_size: int,
|
||||
num_examples: int,
|
||||
generation_time: str,
|
||||
num_skipped: int = 0,
|
||||
max_model_len_used: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Create a comprehensive dataset card documenting the generation process."""
|
||||
filtering_section = ""
|
||||
if num_skipped > 0:
|
||||
skip_percentage = (num_skipped / num_examples) * 100
|
||||
processed = num_examples - num_skipped
|
||||
filtering_section = f"""
|
||||
|
||||
### Filtering Statistics
|
||||
|
||||
- **Total Examples**: {num_examples:,}
|
||||
- **Processed**: {processed:,} ({100 - skip_percentage:.1f}%)
|
||||
- **Skipped (too long)**: {num_skipped:,} ({skip_percentage:.1f}%)
|
||||
- **Max Model Length Used**: {max_model_len_used:,} tokens
|
||||
|
||||
Note: Prompts exceeding the maximum model length were skipped and have empty responses."""
|
||||
|
||||
return f"""---
|
||||
tags:
|
||||
- generated
|
||||
- vllm
|
||||
- uv-script
|
||||
---
|
||||
|
||||
# Generated Responses Dataset
|
||||
|
||||
This dataset contains generated responses for prompts from [{source_dataset}](https://huggingface.co/datasets/{source_dataset}).
|
||||
|
||||
## Generation Details
|
||||
|
||||
- **Source Dataset**: [{source_dataset}](https://huggingface.co/datasets/{source_dataset})
|
||||
- **Input Column**: `{prompt_column if prompt_column else messages_column}` ({"plain text prompts" if prompt_column else "chat messages"})
|
||||
- **Model**: [{model_id}](https://huggingface.co/{model_id})
|
||||
- **Number of Examples**: {num_examples:,}
|
||||
- **Generation Date**: {generation_time}{filtering_section}
|
||||
|
||||
### Sampling Parameters
|
||||
|
||||
- **Temperature**: {sampling_params.temperature}
|
||||
- **Top P**: {sampling_params.top_p}
|
||||
- **Top K**: {sampling_params.top_k}
|
||||
- **Min P**: {sampling_params.min_p}
|
||||
- **Max Tokens**: {sampling_params.max_tokens}
|
||||
- **Repetition Penalty**: {sampling_params.repetition_penalty}
|
||||
|
||||
### Hardware Configuration
|
||||
|
||||
- **Tensor Parallel Size**: {tensor_parallel_size}
|
||||
- **GPU Configuration**: {tensor_parallel_size} GPU(s)
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
The dataset contains all columns from the source dataset plus:
|
||||
- `response`: The generated response from the model
|
||||
|
||||
## Generation Script
|
||||
|
||||
Generated using the vLLM inference script from [uv-scripts/vllm](https://huggingface.co/datasets/uv-scripts/vllm).
|
||||
|
||||
To reproduce this generation:
|
||||
|
||||
```bash
|
||||
uv run https://huggingface.co/datasets/uv-scripts/vllm/raw/main/generate-responses.py \\
|
||||
{source_dataset} \\
|
||||
<output-dataset> \\
|
||||
--model-id {model_id} \\
|
||||
{"--prompt-column " + prompt_column if prompt_column else "--messages-column " + messages_column} \\
|
||||
--temperature {sampling_params.temperature} \\
|
||||
--top-p {sampling_params.top_p} \\
|
||||
--top-k {sampling_params.top_k} \\
--max-tokens {sampling_params.max_tokens}{f" \\\n --max-model-len {max_model_len_used}" if max_model_len_used else ""}
```
|
||||
"""
|
||||
|
||||
|
||||
def main(
|
||||
src_dataset_hub_id: str,
|
||||
output_dataset_hub_id: str,
|
||||
model_id: str = "Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||
messages_column: str = "messages",
|
||||
prompt_column: Optional[str] = None,
|
||||
output_column: str = "response",
|
||||
temperature: float = 0.7,
|
||||
top_p: float = 0.8,
|
||||
top_k: int = 20,
|
||||
min_p: float = 0.0,
|
||||
max_tokens: int = 16384,
|
||||
repetition_penalty: float = 1.0,
|
||||
gpu_memory_utilization: float = 0.90,
|
||||
max_model_len: Optional[int] = None,
|
||||
tensor_parallel_size: Optional[int] = None,
|
||||
skip_long_prompts: bool = True,
|
||||
max_samples: Optional[int] = None,
|
||||
hf_token: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Main generation pipeline.
|
||||
|
||||
Args:
|
||||
src_dataset_hub_id: Input dataset on Hugging Face Hub
|
||||
output_dataset_hub_id: Where to save results on Hugging Face Hub
|
||||
model_id: Hugging Face model ID for generation
|
||||
messages_column: Column name containing chat messages
|
||||
prompt_column: Column name containing plain text prompts (alternative to messages_column)
|
||||
output_column: Column name for generated responses
|
||||
temperature: Sampling temperature
|
||||
top_p: Top-p sampling parameter
|
||||
top_k: Top-k sampling parameter
|
||||
min_p: Minimum probability threshold
|
||||
max_tokens: Maximum tokens to generate
|
||||
repetition_penalty: Repetition penalty parameter
|
||||
gpu_memory_utilization: GPU memory utilization factor
|
||||
max_model_len: Maximum model context length (None uses model default)
|
||||
tensor_parallel_size: Number of GPUs to use (auto-detect if None)
|
||||
skip_long_prompts: Skip prompts exceeding max_model_len instead of failing
|
||||
max_samples: Maximum number of samples to process (None for all)
|
||||
hf_token: Hugging Face authentication token
|
||||
"""
|
||||
generation_start_time = datetime.now().isoformat()
|
||||
|
||||
# GPU check and configuration
|
||||
num_gpus = check_gpu_availability()
|
||||
if tensor_parallel_size is None:
|
||||
tensor_parallel_size = num_gpus
|
||||
logger.info(
|
||||
f"Auto-detected {num_gpus} GPU(s), using tensor_parallel_size={tensor_parallel_size}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Using specified tensor_parallel_size={tensor_parallel_size}")
|
||||
if tensor_parallel_size > num_gpus:
|
||||
logger.warning(
|
||||
f"Requested {tensor_parallel_size} GPUs but only {num_gpus} available"
|
||||
)
|
||||
|
||||
# Authentication - try multiple methods
|
||||
HF_TOKEN = hf_token or os.environ.get("HF_TOKEN") or get_token()
|
||||
|
||||
if not HF_TOKEN:
|
||||
logger.error("No HuggingFace token found. Please provide token via:")
|
||||
logger.error(" 1. --hf-token argument")
|
||||
logger.error(" 2. HF_TOKEN environment variable")
|
||||
logger.error(" 3. Run 'hf auth login' or use login() in Python")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("HuggingFace token found, authenticating...")
|
||||
login(token=HF_TOKEN)
|
||||
|
||||
# Initialize vLLM
|
||||
logger.info(f"Loading model: {model_id}")
|
||||
vllm_kwargs = {
|
||||
"model": model_id,
|
||||
"tensor_parallel_size": tensor_parallel_size,
|
||||
"gpu_memory_utilization": gpu_memory_utilization,
|
||||
}
|
||||
if max_model_len is not None:
|
||||
vllm_kwargs["max_model_len"] = max_model_len
|
||||
logger.info(f"Using max_model_len={max_model_len}")
|
||||
|
||||
llm = LLM(**vllm_kwargs)
|
||||
|
||||
# Load tokenizer for chat template
|
||||
logger.info("Loading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Create sampling parameters
|
||||
sampling_params = SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
max_tokens=max_tokens,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {src_dataset_hub_id}")
|
||||
dataset = load_dataset(src_dataset_hub_id, split="train")
|
||||
|
||||
# Apply max_samples if specified
|
||||
if max_samples is not None and max_samples < len(dataset):
|
||||
logger.info(f"Limiting dataset to {max_samples} samples")
|
||||
dataset = dataset.select(range(max_samples))
|
||||
|
||||
total_examples = len(dataset)
|
||||
logger.info(f"Dataset loaded with {total_examples:,} examples")
|
||||
|
||||
# Determine which column to use and validate
|
||||
if prompt_column:
|
||||
# Use prompt column mode
|
||||
if prompt_column not in dataset.column_names:
|
||||
logger.error(
|
||||
f"Column '{prompt_column}' not found. Available columns: {dataset.column_names}"
|
||||
)
|
||||
sys.exit(1)
|
||||
logger.info(f"Using prompt column mode with column: '{prompt_column}'")
|
||||
use_messages = False
|
||||
else:
|
||||
# Use messages column mode
|
||||
if messages_column not in dataset.column_names:
|
||||
logger.error(
|
||||
f"Column '{messages_column}' not found. Available columns: {dataset.column_names}"
|
||||
)
|
||||
sys.exit(1)
|
||||
logger.info(f"Using messages column mode with column: '{messages_column}'")
|
||||
use_messages = True
|
||||
|
||||
# Get effective max length for filtering
|
||||
if max_model_len is not None:
|
||||
effective_max_len = max_model_len
|
||||
else:
|
||||
# Get model's default max length
|
||||
effective_max_len = llm.llm_engine.model_config.max_model_len
|
||||
logger.info(f"Using effective max model length: {effective_max_len}")
|
||||
|
||||
# Process messages and apply chat template
|
||||
logger.info("Preparing prompts...")
|
||||
all_prompts = []
|
||||
valid_prompts = []
|
||||
valid_indices = []
|
||||
skipped_info = []
|
||||
|
||||
for i, example in enumerate(tqdm(dataset, desc="Processing prompts")):
|
||||
if use_messages:
|
||||
# Messages mode: use existing chat messages
|
||||
messages = example[messages_column]
|
||||
# Apply chat template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
else:
|
||||
# Prompt mode: convert plain text to messages format
|
||||
user_prompt = example[prompt_column]
|
||||
messages = [{"role": "user", "content": user_prompt}]
|
||||
# Apply chat template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
all_prompts.append(prompt)
|
||||
|
||||
# Count tokens if filtering is enabled
|
||||
if skip_long_prompts:
|
||||
tokens = tokenizer.encode(prompt)
|
||||
if len(tokens) <= effective_max_len:
|
||||
valid_prompts.append(prompt)
|
||||
valid_indices.append(i)
|
||||
else:
|
||||
skipped_info.append((i, len(tokens)))
|
||||
else:
|
||||
valid_prompts.append(prompt)
|
||||
valid_indices.append(i)
|
||||
|
||||
# Log filtering results
|
||||
if skip_long_prompts and skipped_info:
|
||||
logger.warning(
|
||||
f"Skipped {len(skipped_info)} prompts that exceed max_model_len ({effective_max_len} tokens)"
|
||||
)
|
||||
logger.info("Skipped prompt details (first 10):")
|
||||
for idx, (prompt_idx, token_count) in enumerate(skipped_info[:10]):
|
||||
logger.info(
|
||||
f" - Example {prompt_idx}: {token_count} tokens (exceeds by {token_count - effective_max_len})"
|
||||
)
|
||||
if len(skipped_info) > 10:
|
||||
logger.info(f" ... and {len(skipped_info) - 10} more")
|
||||
|
||||
skip_percentage = (len(skipped_info) / total_examples) * 100
|
||||
if skip_percentage > 10:
|
||||
logger.warning(f"WARNING: {skip_percentage:.1f}% of prompts were skipped!")
|
||||
|
||||
if not valid_prompts:
|
||||
logger.error("No valid prompts to process after filtering!")
|
||||
sys.exit(1)
|
||||
|
||||
# Generate responses - vLLM handles batching internally
|
||||
logger.info(f"Starting generation for {len(valid_prompts):,} valid prompts...")
|
||||
logger.info("vLLM will handle batching and scheduling automatically")
|
||||
|
||||
outputs = llm.generate(valid_prompts, sampling_params)
|
||||
|
||||
# Extract generated text and create full response list
|
||||
logger.info("Extracting generated responses...")
|
||||
responses = [""] * total_examples # Initialize with empty strings
|
||||
|
||||
for idx, output in enumerate(outputs):
|
||||
original_idx = valid_indices[idx]
|
||||
response = output.outputs[0].text.strip()
|
||||
responses[original_idx] = response
|
||||
|
||||
# Add responses to dataset
|
||||
logger.info("Adding responses to dataset...")
|
||||
dataset = dataset.add_column(output_column, responses)
|
||||
|
||||
# Create dataset card
|
||||
logger.info("Creating dataset card...")
|
||||
card_content = create_dataset_card(
|
||||
source_dataset=src_dataset_hub_id,
|
||||
model_id=model_id,
|
||||
messages_column=messages_column,
|
||||
prompt_column=prompt_column,
|
||||
sampling_params=sampling_params,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
num_examples=total_examples,
|
||||
generation_time=generation_start_time,
|
||||
num_skipped=len(skipped_info) if skip_long_prompts else 0,
|
||||
max_model_len_used=effective_max_len if skip_long_prompts else None,
|
||||
)
|
||||
|
||||
# Push dataset to hub
|
||||
logger.info(f"Pushing dataset to: {output_dataset_hub_id}")
|
||||
dataset.push_to_hub(output_dataset_hub_id, token=HF_TOKEN)
|
||||
|
||||
# Push dataset card
|
||||
card = DatasetCard(card_content)
|
||||
card.push_to_hub(output_dataset_hub_id, token=HF_TOKEN)
|
||||
|
||||
logger.info("✅ Generation complete!")
|
||||
logger.info(
|
||||
f"Dataset available at: https://huggingface.co/datasets/{output_dataset_hub_id}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate responses for dataset prompts using vLLM",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Basic usage with default Qwen model
|
||||
uv run generate-responses.py input-dataset output-dataset
|
||||
|
||||
# With custom model and parameters
|
||||
uv run generate-responses.py input-dataset output-dataset \\
|
||||
--model-id meta-llama/Llama-3.1-8B-Instruct \\
|
||||
--temperature 0.9 \\
|
||||
--max-tokens 2048
|
||||
|
||||
# Force specific GPU configuration
|
||||
uv run generate-responses.py input-dataset output-dataset \\
|
||||
--tensor-parallel-size 2 \\
|
||||
--gpu-memory-utilization 0.95
|
||||
|
||||
# Using environment variable for token
|
||||
HF_TOKEN=hf_xxx uv run generate-responses.py input-dataset output-dataset
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"src_dataset_hub_id",
|
||||
help="Input dataset on Hugging Face Hub (e.g., username/dataset-name)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_dataset_hub_id", help="Output dataset name on Hugging Face Hub"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
type=str,
|
||||
default="Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||
help="Model to use for generation (default: Qwen3-30B-A3B-Instruct-2507)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--messages-column",
|
||||
type=str,
|
||||
default="messages",
|
||||
help="Column containing chat messages (default: messages)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-column",
|
||||
type=str,
|
||||
help="Column containing plain text prompts (alternative to --messages-column)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-column",
|
||||
type=str,
|
||||
default="response",
|
||||
help="Column name for generated responses (default: response)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-samples",
|
||||
type=int,
|
||||
help="Maximum number of samples to process (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Sampling temperature (default: 0.7)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="Top-p sampling parameter (default: 0.8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Top-k sampling parameter (default: 20)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-p",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Minimum probability threshold (default: 0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=16384,
|
||||
help="Maximum tokens to generate (default: 16384)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repetition-penalty",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Repetition penalty (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.90,
|
||||
help="GPU memory utilization factor (default: 0.90)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
help="Maximum model context length (default: model's default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tensor-parallel-size",
|
||||
type=int,
|
||||
help="Number of GPUs to use (default: auto-detect)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-token",
|
||||
type=str,
|
||||
help="Hugging Face token (can also use HF_TOKEN env var)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-long-prompts",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Skip prompts that exceed max_model_len instead of failing (default: True)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-skip-long-prompts",
|
||||
dest="skip_long_prompts",
|
||||
action="store_false",
|
||||
help="Fail on prompts that exceed max_model_len",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(
|
||||
src_dataset_hub_id=args.src_dataset_hub_id,
|
||||
output_dataset_hub_id=args.output_dataset_hub_id,
|
||||
model_id=args.model_id,
|
||||
messages_column=args.messages_column,
|
||||
prompt_column=args.prompt_column,
|
||||
output_column=args.output_column,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
min_p=args.min_p,
|
||||
max_tokens=args.max_tokens,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
max_model_len=args.max_model_len,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
skip_long_prompts=args.skip_long_prompts,
|
||||
max_samples=args.max_samples,
|
||||
hf_token=args.hf_token,
|
||||
)
|
||||
else:
# Show HF Jobs example when run without arguments
print("""
vLLM Response Generation Script
===============================
This script requires arguments. For usage information:
uv run generate-responses.py --help
Example HF Jobs command with multi-GPU:
# If you're logged in with hf auth, token will be auto-detected
hf jobs uv run \\
--flavor l4x4 \\
https://huggingface.co/datasets/uv-scripts/vllm/raw/main/generate-responses.py \\
username/input-dataset \\
username/output-dataset \\
--messages-column messages \\
--model-id Qwen/Qwen3-30B-A3B-Instruct-2507 \\
|
||||
--temperature 0.7 \\
|
||||
--max-tokens 16384
|
||||
""")
|
||||
@@ -1,9 +1,9 @@
 ---
 source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-llm-trainer"
 name: hugging-face-model-trainer
-description: "Train language models using TRL (Transformer Reinforcement Learning) on fully managed Hugging Face infrastructure. No local GPU setup required—models train on cloud GPUs and results are automatically saved to the Hugging Face Hub."
+description: Train or fine-tune TRL language models on Hugging Face Jobs, including SFT, DPO, GRPO, and GGUF export.
 license: Complete terms in LICENSE.txt
 risk: unknown
 source: community
 ---

 # TRL Training on Hugging Face Jobs
@@ -60,11 +60,12 @@ When assisting with training jobs:

 4. **Use example scripts as templates** - Reference `scripts/train_sft_example.py`, `scripts/train_dpo_example.py`, etc. as starting points.

-## Local Script Dependencies
+## Local Script Execution

-To run scripts locally (like `estimate_cost.py`), install dependencies:
+Repository scripts use PEP 723 inline dependencies. Run them with `uv run`:
 ```bash
-pip install -r requirements.txt
+uv run scripts/estimate_cost.py --help
+uv run scripts/dataset_inspector.py --help
 ```

 ## Prerequisites Checklist
@@ -240,8 +241,8 @@ hf_jobs("uv", {"script": "https://gist.githubusercontent.com/user/id/raw/train.p

 **To use local scripts:** Upload to HF Hub first:
 ```bash
-huggingface-cli repo create my-training-scripts --type model
-huggingface-cli upload my-training-scripts ./train.py train.py
+hf repos create my-training-scripts --type model
+hf upload my-training-scripts ./train.py train.py
 # Use: https://huggingface.co/USERNAME/my-training-scripts/resolve/main/train.py
 ```

@@ -331,13 +332,10 @@ hf jobs cancel <job-id>  # Cancel a job
 The `trl-jobs` package provides optimized defaults and one-liner training.

 ```bash
-# Install
-pip install trl-jobs
-
 # Train with SFT (simplest possible)
-trl-jobs sft \
+uvx trl-jobs sft \
   --model_name Qwen/Qwen2.5-0.5B \
   --dataset_name trl-lib/Capybara

 ```

 **Benefits:** Pre-configured settings, automatic Trackio integration, automatic Hub push, one-line commands
@@ -685,6 +683,7 @@ Add to PEP 723 header:
 - `references/hardware_guide.md` - Hardware specs and selection
 - `references/hub_saving.md` - Hub authentication troubleshooting
 - `references/troubleshooting.md` - Common issues and solutions
+- `references/local_training_macos.md` - Local training on macOS

 ### Scripts (In This Skill)
 - `scripts/train_sft_example.py` - Production SFT template
@@ -0,0 +1,296 @@
# GGUF Conversion Guide

After training models with TRL on Hugging Face Jobs, convert them to **GGUF format** for use with llama.cpp, Ollama, LM Studio, and other local inference tools.

**This guide provides production-ready, tested code based on successful conversions.** All critical dependencies and build steps are included.

## What is GGUF?

**GGUF** (GPT-Generated Unified Format):
- Optimized format for CPU/GPU inference with llama.cpp
- Supports quantization (4-bit, 5-bit, 8-bit) to reduce model size
- Compatible with: Ollama, LM Studio, Jan, GPT4All, llama.cpp
- Typically 2-8GB for 7B models (vs 14GB unquantized)

## When to Convert to GGUF

**Convert when:**
- Running models locally with Ollama or LM Studio
- Using CPU-optimized inference
- Reducing model size with quantization
- Deploying to edge devices
- Sharing models for local-first use

## Critical Success Factors

Based on production testing, these are **essential** for reliable conversion:
### 1. ✅ Install Build Tools FIRST
**Before cloning llama.cpp**, install build dependencies:
```python
import subprocess

subprocess.run(["apt-get", "update", "-qq"], check=True, capture_output=True)
subprocess.run(["apt-get", "install", "-y", "-qq", "build-essential", "cmake"], check=True, capture_output=True)
```

**Why:** The quantization tool requires gcc and cmake, so install them up front; discovering they are missing after the CMake build starts means a failed job.
### 2. ✅ Use CMake (Not Make)
**Build the quantize tool with CMake:**
```python
import os
import subprocess

# Create build directory
os.makedirs("/tmp/llama.cpp/build", exist_ok=True)

# Configure
subprocess.run([
    "cmake", "-B", "/tmp/llama.cpp/build", "-S", "/tmp/llama.cpp",
    "-DGGML_CUDA=OFF"  # Faster build; CUDA not needed for quantization
], check=True, capture_output=True, text=True)

# Build
subprocess.run([
    "cmake", "--build", "/tmp/llama.cpp/build",
    "--target", "llama-quantize", "-j", "4"
], check=True, capture_output=True, text=True)

# Binary path
quantize_bin = "/tmp/llama.cpp/build/bin/llama-quantize"
```

**Why:** CMake is more reliable than `make` and produces consistent binary paths.
### 3. ✅ Include All Dependencies
**PEP 723 header must include:**
```python
# /// script
# dependencies = [
#     "transformers>=4.36.0",
#     "peft>=0.7.0",
#     "torch>=2.0.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "sentencepiece>=0.1.99",  # Required for tokenizer
#     "protobuf>=3.20.0",       # Required for tokenizer
#     "numpy",
#     "gguf",
# ]
# ///
```

**Why:** `sentencepiece` and `protobuf` are critical for tokenizer conversion. Missing them causes silent failures.

### 4. ✅ Verify Names Before Use
**Always verify repos exist:**
```python
# Before submitting job, verify:
hub_repo_details([ADAPTER_MODEL], repo_type="model")
hub_repo_details([BASE_MODEL], repo_type="model")
```

**Why:** Non-existent dataset/model names cause job failures that could be caught in seconds.
## Complete Conversion Script

See `scripts/convert_to_gguf.py` for the complete, production-ready script.

**Key features:**
- ✅ All dependencies in PEP 723 header
- ✅ Build tools installed automatically
- ✅ CMake build process (reliable)
- ✅ Comprehensive error handling
- ✅ Environment variable configuration
- ✅ Automatic README generation

## Quick Conversion Job

```python
# Before submitting: VERIFY MODELS EXIST
hub_repo_details(["username/my-finetuned-model"], repo_type="model")
hub_repo_details(["Qwen/Qwen2.5-0.5B"], repo_type="model")

# Submit conversion job
hf_jobs("uv", {
    "script": open("trl/scripts/convert_to_gguf.py").read(),  # Or inline the script
    "flavor": "a10g-large",
    "timeout": "45m",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"},
    "env": {
        "ADAPTER_MODEL": "username/my-finetuned-model",
        "BASE_MODEL": "Qwen/Qwen2.5-0.5B",
        "OUTPUT_REPO": "username/my-model-gguf",
        "HF_USERNAME": "username"  # Optional, for README
    }
})
```

## Conversion Process

The script performs these steps:

1. **Load and Merge** - Load base model and LoRA adapter, merge them
2. **Install Build Tools** - Install gcc, cmake (CRITICAL: before cloning llama.cpp)
3. **Setup llama.cpp** - Clone repo, install Python dependencies
4. **Convert to GGUF** - Create FP16 GGUF using llama.cpp converter
5. **Build Quantize Tool** - Use CMake to build `llama-quantize`
6. **Quantize** - Create Q4_K_M, Q5_K_M, Q8_0 versions
7. **Upload** - Upload all versions + README to Hub

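Steps 4-6 boil down to assembling two commands. A minimal sketch (paths assume the `/tmp/llama.cpp` checkout used above; `convert_hf_to_gguf.py` is the converter script shipped with recent llama.cpp, so verify the name in your checkout):

```python
from pathlib import Path

LLAMA_CPP = Path("/tmp/llama.cpp")  # clone location used in this guide

def convert_cmd(model_dir: str, fp16_out: str) -> list[str]:
    # Step 4: merged HF checkpoint -> FP16 GGUF via llama.cpp's converter
    return ["python", str(LLAMA_CPP / "convert_hf_to_gguf.py"),
            model_dir, "--outfile", fp16_out, "--outtype", "f16"]

def quantize_cmd(fp16_in: str, gguf_out: str, quant: str) -> list[str]:
    # Step 6: quantize with the CMake-built binary from build/bin/
    return [str(LLAMA_CPP / "build/bin/llama-quantize"), fp16_in, gguf_out, quant]

# Commands for the three recommended quantizations
for q in ("Q4_K_M", "Q5_K_M", "Q8_0"):
    print(quantize_cmd("model-f16.gguf", f"model-{q.lower()}.gguf", q))
```

Each returned list can be passed directly to `subprocess.run(..., check=True)`, as in the build steps earlier.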
## Quantization Options

Common quantization formats (from smallest to largest; sizes shown are for a ~0.5B model):

| Format | Size | Quality | Use Case |
|--------|------|---------|----------|
| **Q4_K_M** | ~300MB | Good | **Recommended** - best balance of size/quality |
| **Q5_K_M** | ~350MB | Better | Higher quality, slightly larger |
| **Q8_0** | ~500MB | Very High | Near-original quality |
| **F16** | ~1GB | Original | Full precision, largest file |

**Recommendation:** Create Q4_K_M, Q5_K_M, and Q8_0 versions to give users options.
## Hardware Requirements

**For conversion:**
- Small models (<1B): cpu-basic works, but slow
- Medium models (1-7B): a10g-large recommended
- Large models (7B+): a10g-large or a100-large

**Time estimates:**
- 0.5B model: ~15-25 minutes on A10G
- 3B model: ~30-45 minutes on A10G
- 7B model: ~45-60 minutes on A10G

## Using GGUF Models

**GGUF models work on both CPU and GPU.** They're optimized for CPU inference but can also leverage GPU acceleration when available.

### With Ollama (auto-detects GPU)
```bash
# Download GGUF
hf download username/my-model-gguf model-q4_k_m.gguf

# Create Modelfile
echo "FROM ./model-q4_k_m.gguf" > Modelfile

# Create and run (uses GPU automatically if available)
ollama create my-model -f Modelfile
ollama run my-model
```

### With llama.cpp
```bash
# CPU only
./llama-cli -m model-q4_k_m.gguf -p "Your prompt"

# With GPU acceleration (offload 32 layers to GPU)
./llama-cli -m model-q4_k_m.gguf -ngl 32 -p "Your prompt"
```

### With LM Studio
1. Download the `.gguf` file
2. Import into LM Studio
3. Start chatting
## Best Practices

### ✅ DO:
1. **Verify repos exist** before submitting jobs (use `hub_repo_details`)
2. **Install build tools FIRST** before cloning llama.cpp
3. **Use CMake** for building the quantize tool (not make)
4. **Include all dependencies** in the PEP 723 header (especially sentencepiece, protobuf)
5. **Create multiple quantizations** - Give users choice
6. **Test on known models** before production use
7. **Use an A10G GPU** for faster conversion

### ❌ DON'T:
1. **Assume repos exist** - Always verify with hub tools
2. **Use make** instead of CMake - Less reliable
3. **Remove dependencies** to "simplify" - They're all needed
4. **Skip build tools** - Quantization will fail silently
5. **Assume default binary paths** - CMake puts binaries in `build/bin/`
## Common Issues

### Out of memory during merge
**Fix:**
- Use a larger GPU (a10g-large or a100-large)
- Ensure `device_map="auto"` for automatic placement
- Use `dtype=torch.float16` or `torch.bfloat16`

### Conversion fails with architecture error
**Fix:**
- Ensure llama.cpp supports the model architecture
- Check for a standard architecture (Qwen, Llama, Mistral, etc.)
- Update llama.cpp to latest: `git clone --depth 1 https://github.com/ggerganov/llama.cpp.git`
- Check llama.cpp documentation for model support

### Quantization fails
**Fix:**
- Verify build tools installed: `apt-get install build-essential cmake`
- Use CMake (not make) to build the quantize tool
- Check binary path: `/tmp/llama.cpp/build/bin/llama-quantize`
- Verify the FP16 GGUF exists before quantizing

### Missing sentencepiece error
**Fix:**
- Add to PEP 723 header: `"sentencepiece>=0.1.99", "protobuf>=3.20.0"`
- Don't remove dependencies to "simplify" - all are required

### Upload fails or times out
**Fix:**
- Large models (>2GB) need a longer timeout: `"timeout": "1h"`
- Upload quantized versions separately if needed
- Check network/Hub status

## Lessons Learned

These are from production testing and real failures:

### 1. Always Verify Before Use
**Lesson:** Don't assume repos/datasets exist. Check first.
```python
# BEFORE submitting job
hub_repo_details(["trl-lib/argilla-dpo-mix-7k"], repo_type="dataset")  # Would catch error
```
**Prevented failures:** Non-existent dataset names, typos in model names

### 2. Prioritize Reliability Over Performance
**Lesson:** Default to what's most likely to succeed.
- Use CMake (not make) - more reliable
- Disable CUDA in the build - faster, not needed
- Include all dependencies - don't "simplify"

**Prevented failures:** Build failures, missing binaries

### 3. Create Atomic, Self-Contained Scripts
**Lesson:** Don't remove dependencies or steps. Scripts should work as a unit.
- All dependencies in the PEP 723 header
- All build steps included
- Clear error messages

**Prevented failures:** Missing tokenizer libraries, build tool failures

## References

**In this skill:**
- `scripts/convert_to_gguf.py` - Complete, production-ready script

**External:**
- [llama.cpp Repository](https://github.com/ggerganov/llama.cpp)
- [GGUF Specification](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md)
- [Ollama Documentation](https://ollama.ai)
- [LM Studio](https://lmstudio.ai)

## Summary

**Critical checklist for GGUF conversion:**
- [ ] Verify adapter and base models exist on Hub
- [ ] Use production script from `scripts/convert_to_gguf.py`
- [ ] All dependencies in PEP 723 header (including sentencepiece, protobuf)
- [ ] Build tools installed before cloning llama.cpp
- [ ] CMake used for building quantize tool (not make)
- [ ] Correct binary path: `/tmp/llama.cpp/build/bin/llama-quantize`
- [ ] A10G GPU selected for reasonable conversion time
- [ ] Timeout set to 45m minimum
- [ ] HF_TOKEN in secrets for Hub upload

**The script in `scripts/convert_to_gguf.py` incorporates all these lessons and has been tested successfully in production.**
@@ -0,0 +1,283 @@
# Hardware Selection Guide

Choosing the right hardware (flavor) is critical for cost-effective training.

## Available Hardware

### CPU
- `cpu-basic` - Basic CPU, testing only
- `cpu-upgrade` - Enhanced CPU

**Use cases:** Dataset validation, preprocessing, testing scripts
**Not recommended for training:** Too slow for any meaningful training

### GPU Options

| Flavor | GPU | Memory | Use Case | Cost/hour |
|--------|-----|--------|----------|-----------|
| `t4-small` | NVIDIA T4 | 16GB | <1B models, demos | ~$0.50-1 |
| `t4-medium` | NVIDIA T4 | 16GB | 1-3B models, development | ~$1-2 |
| `l4x1` | NVIDIA L4 | 24GB | 3-7B models, efficient training | ~$2-3 |
| `l4x4` | 4x NVIDIA L4 | 96GB | Multi-GPU training | ~$8-12 |
| `a10g-small` | NVIDIA A10G | 24GB | 3-7B models, production | ~$3-4 |
| `a10g-large` | NVIDIA A10G | 24GB | 7-13B models | ~$4-6 |
| `a10g-largex2` | 2x NVIDIA A10G | 48GB | Multi-GPU, large models | ~$8-12 |
| `a10g-largex4` | 4x NVIDIA A10G | 96GB | Multi-GPU, very large models | ~$16-24 |
| `a100-large` | NVIDIA A100 | 40GB | 13B+ models, fast training | ~$8-12 |

### TPU Options

| Flavor | Type | Use Case |
|--------|------|----------|
| `v5e-1x1` | TPU v5e | Small TPU workloads |
| `v5e-2x2` | 4x TPU v5e | Medium TPU workloads |
| `v5e-2x4` | 8x TPU v5e | Large TPU workloads |

**Note:** TPUs require TPU-optimized code. Most TRL training uses GPUs.
## Selection Guidelines

### By Model Size

**Tiny Models (<1B parameters)**
- **Recommended:** `t4-small`
- **Example:** Qwen2.5-0.5B, TinyLlama
- **Batch size:** 4-8
- **Training time:** 1-2 hours for 1K examples

**Small Models (1-3B parameters)**
- **Recommended:** `t4-medium` or `a10g-small`
- **Example:** Qwen2.5-1.5B, Phi-2
- **Batch size:** 2-4
- **Training time:** 2-4 hours for 10K examples

**Medium Models (3-7B parameters)**
- **Recommended:** `a10g-small` or `a10g-large`
- **Example:** Qwen2.5-7B, Mistral-7B
- **Batch size:** 1-2 (or LoRA with 4-8)
- **Training time:** 4-8 hours for 10K examples

**Large Models (7-13B parameters)**
- **Recommended:** `a10g-large` or `a100-large`
- **Example:** Llama-3-8B, Mixtral-8x7B (with LoRA)
- **Batch size:** 1 (full fine-tuning) or 2-4 (LoRA)
- **Training time:** 6-12 hours for 10K examples
- **Note:** Always use LoRA/PEFT

**Very Large Models (13B+ parameters)**
- **Recommended:** `a100-large` with LoRA
- **Example:** Llama-2-13B, Llama-3-70B (LoRA only)
- **Batch size:** 1-2 with LoRA
- **Training time:** 8-24 hours for 10K examples
- **Note:** Full fine-tuning not feasible, use LoRA/PEFT
### By Budget

**Minimal Budget (<$5 total)**
- Use `t4-small`
- Train on a subset of data (100-500 examples)
- Limit to 1-2 epochs
- Use a small model (<1B)

**Small Budget ($5-20)**
- Use `t4-medium` or `a10g-small`
- Train on 1K-5K examples
- 2-3 epochs
- Model up to 3B parameters

**Medium Budget ($20-50)**
- Use `a10g-small` or `a10g-large`
- Train on 5K-20K examples
- 3-5 epochs
- Model up to 7B parameters

**Large Budget ($50-200)**
- Use `a10g-large` or `a100-large`
- Full dataset training
- Multiple epochs
- Model up to 13B parameters with LoRA

### By Training Type

**Quick Demo/Experiment**
- `t4-small`
- 50-100 examples
- 5-10 steps
- ~10-15 minutes

**Development/Iteration**
- `t4-medium` or `a10g-small`
- 1K examples
- 1 epoch
- ~30-60 minutes

**Production Training**
- `a10g-large` or `a100-large`
- Full dataset
- 3-5 epochs
- 4-12 hours

**Research/Experimentation**
- `a100-large`
- Multiple runs
- Various hyperparameters
- Budget for 20-50 hours
## Memory Considerations

### Estimating Memory Requirements

**Full fine-tuning:**
```
Memory (GB) ≈ (Model params in billions) × 20
```

**LoRA fine-tuning:**
```
Memory (GB) ≈ (Model params in billions) × 4
```

**Examples:**
- Qwen2.5-0.5B full: ~10GB ✅ fits t4-small
- Qwen2.5-1.5B full: ~30GB ❌ exceeds most GPUs
- Qwen2.5-1.5B LoRA: ~6GB ✅ fits t4-small
- Qwen2.5-7B full: ~140GB ❌ not feasible
- Qwen2.5-7B LoRA: ~28GB ✅ fits a10g-largex2 (2×24GB)

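These rules of thumb are easy to turn into a helper; a small sketch (the ×20 and ×4 factors are the approximations above, not exact measurements):

```python
def est_memory_gb(params_b: float, lora: bool = False) -> float:
    # Rough heuristics from this guide: full fine-tune ~20 GB per billion
    # parameters, LoRA fine-tune ~4 GB per billion parameters
    return params_b * (4 if lora else 20)

def fits(gpu_mem_gb: float, params_b: float, lora: bool = False) -> bool:
    return est_memory_gb(params_b, lora) <= gpu_mem_gb

print(est_memory_gb(0.5))             # 10.0 -> fits t4-small (16 GB)
print(est_memory_gb(1.5, lora=True))  # 6.0  -> fits t4-small
print(fits(16, 1.5))                  # False: 1.5B full needs ~30 GB
```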
### Memory Optimization

If hitting memory limits:

1. **Use LoRA/PEFT**
   ```python
   peft_config=LoraConfig(r=16, lora_alpha=32)
   ```

2. **Reduce batch size**
   ```python
   per_device_train_batch_size=1
   ```

3. **Increase gradient accumulation**
   ```python
   gradient_accumulation_steps=8  # Effective batch size = 1×8
   ```

4. **Enable gradient checkpointing**
   ```python
   gradient_checkpointing=True
   ```

5. **Use mixed precision**
   ```python
   bf16=True  # or fp16=True
   ```

6. **Upgrade to a larger GPU**
   - t4 → a10g → a100
## Cost Estimation

### Formula

```
Total Cost = (Hours of training) × (Cost per hour)
```

### Example Calculations

**Quick demo:**
- Hardware: t4-small ($0.75/hour)
- Time: 15 minutes (0.25 hours)
- Cost: $0.19

**Development training:**
- Hardware: a10g-small ($3.50/hour)
- Time: 2 hours
- Cost: $7.00

**Production training:**
- Hardware: a10g-large ($5/hour)
- Time: 6 hours
- Cost: $30.00

**Large model with LoRA:**
- Hardware: a100-large ($10/hour)
- Time: 8 hours
- Cost: $80.00

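The same arithmetic as a tiny helper (the $/hour figures are the approximate rates used in the examples above, not official pricing):

```python
# Approximate $/hour used in the examples above (not official pricing)
RATES = {"t4-small": 0.75, "a10g-small": 3.50, "a10g-large": 5.00, "a100-large": 10.00}

def job_cost(flavor: str, hours: float) -> float:
    # Total Cost = hours × cost per hour, rounded to cents
    return round(RATES[flavor] * hours, 2)

print(job_cost("t4-small", 0.25))  # 0.19 (quick demo)
print(job_cost("a10g-large", 6))   # 30.0 (production training)
```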
### Cost Optimization Tips

1. **Start small:** Test on t4-small with a subset
2. **Use LoRA:** 4-5x cheaper than full fine-tuning
3. **Optimize hyperparameters:** Fewer epochs if possible
4. **Set an appropriate timeout:** Don't waste compute on stalled jobs
5. **Use checkpointing:** Resume if a job fails
6. **Monitor costs:** Check running jobs regularly

## Multi-GPU Training

TRL automatically handles multi-GPU training with Accelerate when using multi-GPU flavors.

**Multi-GPU flavors:**
- `l4x4` - 4x L4 GPUs
- `a10g-largex2` - 2x A10G GPUs
- `a10g-largex4` - 4x A10G GPUs

**When to use:**
- Models >13B parameters
- Need faster training (near-linear speedup)
- Large datasets (>50K examples)

**Example:**
```python
hf_jobs("uv", {
    "script": "train.py",
    "flavor": "a10g-largex2",  # 2 GPUs
    "timeout": "4h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

No code changes needed—TRL/Accelerate handles distribution automatically.

## Choosing Between Options

### a10g vs a100

**Choose a10g when:**
- Model <13B parameters
- Budget conscious
- Training time not critical

**Choose a100 when:**
- Model 13B+ parameters
- Need fastest training
- Memory requirements high
- Budget allows

### Single vs Multi-GPU

**Choose single GPU when:**
- Model <7B parameters
- Budget constrained
- Simpler debugging

**Choose multi-GPU when:**
- Model >13B parameters
- Need faster training
- Large batch sizes required
- Cost-effective for large jobs

## Quick Reference

```python
# Model size → Hardware selection
HARDWARE_MAP = {
    "<1B": "t4-small",
    "1-3B": "a10g-small",
    "3-7B": "a10g-large",
    "7-13B": "a10g-large (LoRA) or a100-large",
    ">13B": "a100-large (LoRA required)"
}
```
@@ -0,0 +1,364 @@
# Saving Training Results to Hugging Face Hub

**⚠️ CRITICAL:** Training environments are ephemeral. ALL results are lost when a job completes unless pushed to the Hub.

## Why Hub Push is Required

When running on Hugging Face Jobs:
- Environment is temporary
- All files deleted on job completion
- No local disk persistence
- Cannot access results after job ends

**Without Hub push, training is completely wasted.**

## Required Configuration
### 1. Training Configuration

In your SFTConfig or trainer config:

```python
SFTConfig(
    push_to_hub=True,                    # Enable Hub push
    hub_model_id="username/model-name",  # Target repository
)
```

### 2. Job Configuration

When submitting the job:

```python
hf_jobs("uv", {
    "script": "train.py",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # Provide authentication
})
```

**The `$HF_TOKEN` placeholder is automatically replaced with your Hugging Face token.**

## Complete Example

```python
# train.py
# /// script
# dependencies = ["trl"]
# ///

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

# Configure with Hub push
config = SFTConfig(
    output_dir="my-model",
    num_train_epochs=3,

    # ✅ CRITICAL: Hub push configuration
    push_to_hub=True,
    hub_model_id="myusername/my-trained-model",

    # Optional: push strategy and token
    hub_strategy="every_save",  # or "checkpoint" to support resuming
    hub_token=None,             # None = use the HF_TOKEN from the environment
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    args=config,
)

trainer.train()

# ✅ Push final model
trainer.push_to_hub()

print("✅ Model saved to: https://huggingface.co/myusername/my-trained-model")
```

**Submit with authentication:**

```python
hf_jobs("uv", {
    "script": "train.py",
    "flavor": "a10g-large",
    "timeout": "2h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # ✅ Required!
})
```
## What Gets Saved

When `push_to_hub=True`:

1. **Model weights** - Final trained parameters
2. **Tokenizer** - Associated tokenizer
3. **Configuration** - Model config (config.json)
4. **Training arguments** - Hyperparameters used
5. **Model card** - Auto-generated documentation
6. **Checkpoints** - If `save_strategy="steps"` enabled
## Checkpoint Saving

Save intermediate checkpoints during training:

```python
SFTConfig(
    output_dir="my-model",
    push_to_hub=True,
    hub_model_id="username/my-model",

    # Checkpoint configuration
    save_strategy="steps",
    save_steps=100,      # Save every 100 steps
    save_total_limit=3,  # Keep only last 3 checkpoints
)
```

**Benefits:**
- Resume training if a job fails
- Compare checkpoint performance
- Use intermediate models

**Checkpoints are pushed to:** `username/my-model` (same repo)
## Authentication Methods

### Method 1: Automatic Token (Recommended)

```python
"secrets": {"HF_TOKEN": "$HF_TOKEN"}
```

Uses your logged-in Hugging Face token automatically.

### Method 2: Explicit Token

```python
"secrets": {"HF_TOKEN": "hf_abc123..."}
```

Provide the token explicitly (not recommended for security).

### Method 3: Environment Variable

```python
"env": {"HF_TOKEN": "hf_abc123..."}
```

Pass as a regular environment variable (less secure than secrets).

**Always prefer Method 1** for security and convenience.

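Whichever method you use, a training script can fail fast if the token never arrived, instead of failing with a 401 at push time. A small illustrative helper (not part of any Hugging Face API):

```python
import os

def require_hf_token(env=None) -> str:
    # Check for the token before doing anything expensive;
    # env parameter is only for testing, defaults to os.environ
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN")
    if not token:
        raise RuntimeError(
            "HF_TOKEN not set; submit the job with secrets={'HF_TOKEN': '$HF_TOKEN'}"
        )
    return token

print(require_hf_token({"HF_TOKEN": "hf_example"}))  # hf_example
```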
## Verification Checklist

Before submitting any training job, verify:

- [ ] `push_to_hub=True` in training config
- [ ] `hub_model_id` is specified (format: `username/model-name`)
- [ ] `secrets={"HF_TOKEN": "$HF_TOKEN"}` in job config
- [ ] Repository name doesn't conflict with existing repos
- [ ] You have write access to the target namespace
## Repository Setup

### Automatic Creation

If the repository doesn't exist, it's created automatically on the first push.

### Manual Creation

Create the repository before training:

```python
from huggingface_hub import HfApi

api = HfApi()
api.create_repo(
    repo_id="username/model-name",
    repo_type="model",
    private=False,  # or True for a private repo
)
```

### Repository Naming

**Valid names:**
- `username/my-model`
- `username/model-name`
- `organization/model-name`

**Invalid names:**
- `model-name` (missing username)
- `username/model name` (spaces not allowed)
- `username/MODEL` (uppercase discouraged)

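A quick illustrative check for these rules (deliberately stricter than the Hub's actual validation, e.g. it rejects uppercase outright, so treat it as a lint rather than the official rule):

```python
import re

# namespace/name: lowercase letters, digits, '.', '_', '-' (illustrative only)
_REPO_ID = re.compile(r"[a-z0-9][a-z0-9._-]*/[a-z0-9][a-z0-9._-]*")

def looks_like_valid_repo_id(repo_id: str) -> bool:
    return _REPO_ID.fullmatch(repo_id) is not None

print(looks_like_valid_repo_id("username/my-model"))    # True
print(looks_like_valid_repo_id("model-name"))           # False (missing username)
print(looks_like_valid_repo_id("username/model name"))  # False (space)
print(looks_like_valid_repo_id("username/MODEL"))       # False (uppercase)
```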
## Troubleshooting

### Error: 401 Unauthorized

**Cause:** HF_TOKEN not provided or invalid

**Solutions:**
1. Verify `secrets={"HF_TOKEN": "$HF_TOKEN"}` in job config
2. Check you're logged in: `hf auth whoami`
3. Re-login: `hf auth login`

### Error: 403 Forbidden

**Cause:** No write access to repository

**Solutions:**
1. Check the repository namespace matches your username
2. Verify you're a member of the organization (if using an org namespace)
3. Check the repository isn't private (if accessing an org repo)

### Error: Repository not found

**Cause:** Repository doesn't exist and auto-creation failed

**Solutions:**
1. Manually create the repository first
2. Check the repository name format
3. Verify the namespace exists

### Error: Push failed during training

**Cause:** Network issues or Hub unavailable

**What happens and what to do:**
1. Training continues even if an intermediate push fails
2. Earlier checkpoints may already be on the Hub
3. Re-run the push manually before the job environment exits

### Issue: Model saved but not visible

**Possible causes:**
1. Repository is private—check https://huggingface.co/username
2. Wrong namespace—verify `hub_model_id` matches login
3. Push still in progress—wait a few minutes
## Manual Push After Training

If training completes but the push fails, push manually:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load from local checkpoint
model = AutoModelForCausalLM.from_pretrained("./output_dir")
tokenizer = AutoTokenizer.from_pretrained("./output_dir")

# Push to Hub
model.push_to_hub("username/model-name", token="hf_abc123...")
tokenizer.push_to_hub("username/model-name", token="hf_abc123...")
```

**Note:** Only possible while the job environment still exists—local files are deleted once the job completes—so run this from within the job before it exits.
## Best Practices

1. **Always enable `push_to_hub=True`**
2. **Use checkpoint saving** for long training runs
3. **Verify Hub push** in logs before the job completes
4. **Set an appropriate `save_total_limit`** to avoid excessive checkpoints
5. **Use descriptive repo names** (e.g., `qwen-capybara-sft` not `model1`)
6. **Add a model card** with training details
7. **Tag models** with relevant tags (e.g., `text-generation`, `fine-tuned`)

## Monitoring Push Progress

Check logs for push progress:

```python
hf_jobs("logs", {"job_id": "your-job-id"})
```

**Look for:**
```
Pushing model to username/model-name...
Upload file pytorch_model.bin: 100%
✅ Model pushed successfully
```
## Example: Full Production Setup
|
||||
|
||||
```python
|
||||
# production_train.py
|
||||
# /// script
|
||||
# dependencies = ["trl>=0.12.0", "peft>=0.7.0"]
|
||||
# ///
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
import os
|
||||
|
||||
# Verify token is available
|
||||
assert "HF_TOKEN" in os.environ, "HF_TOKEN not found in environment!"
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
print(f"✅ Dataset loaded: {len(dataset)} examples")
|
||||
|
||||
# Configure with comprehensive Hub settings
|
||||
config = SFTConfig(
|
||||
output_dir="qwen-capybara-sft",
|
||||
|
||||
# Hub configuration
|
||||
push_to_hub=True,
|
||||
hub_model_id="myusername/qwen-capybara-sft",
|
||||
hub_strategy="checkpoint", # Push checkpoints
|
||||
|
||||
# Checkpoint configuration
|
||||
save_strategy="steps",
|
||||
save_steps=100,
|
||||
save_total_limit=3,
|
||||
|
||||
# Training settings
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
|
||||
# Logging
|
||||
logging_steps=10,
|
||||
logging_first_step=True,
|
||||
)
|
||||
|
||||
# Train with LoRA
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
args=config,
|
||||
peft_config=LoraConfig(r=16, lora_alpha=32),
|
||||
)
|
||||
|
||||
print("🚀 Starting training...")
|
||||
trainer.train()
|
||||
|
||||
print("💾 Pushing final model to Hub...")
|
||||
trainer.push_to_hub()
|
||||
|
||||
print("✅ Training complete!")
|
||||
print(f"Model available at: https://huggingface.co/myusername/qwen-capybara-sft")
|
||||
```
|
||||
|
||||
**Submit:**
|
||||
|
||||
```python
|
||||
hf_jobs("uv", {
|
||||
"script": "production_train.py",
|
||||
"flavor": "a10g-large",
|
||||
"timeout": "6h",
|
||||
"secrets": {"HF_TOKEN": "$HF_TOKEN"}
|
||||
})
|
||||
```
|
||||
|
||||
## Key Takeaway
|
||||
|
||||
**Without `push_to_hub=True` and `secrets={"HF_TOKEN": "$HF_TOKEN"}`, all training results are permanently lost.**
|
||||
|
||||
Always verify both are configured before submitting any training job.
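
Those two checks can be done mechanically before submission. A minimal sketch (the `preflight` helper and its argument shapes are illustrative, mirroring the trainer config and `hf_jobs` payload used in this guide):

```python
def preflight(training_config: dict, job_payload: dict) -> list:
    """Return a list of problems that would cause training results to be lost."""
    problems = []
    if not training_config.get("push_to_hub"):
        problems.append("push_to_hub is not enabled in the trainer config")
    if "HF_TOKEN" not in job_payload.get("secrets", {}):
        problems.append('secrets is missing {"HF_TOKEN": "$HF_TOKEN"}')
    return problems

# A correctly configured job produces no problems
ok = preflight({"push_to_hub": True}, {"secrets": {"HF_TOKEN": "$HF_TOKEN"}})
bad = preflight({}, {})
```

Submit only when the returned list is empty.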

@@ -0,0 +1,231 @@

# Local Training on macOS (Apple Silicon)

Run small LoRA fine-tuning jobs locally on Mac for smoke tests and quick iteration before submitting to HF Jobs.

## When to Use Local Mac vs HF Jobs

| Local Mac | HF Jobs / Cloud GPU |
|-----------|-------------------|
| Model ≤3B, text-only | Model 7B+ |
| LoRA/PEFT only | QLoRA 4-bit (CUDA/bitsandbytes) |
| Short context (≤1024) | Long context / full fine-tuning |
| Smoke tests, dataset validation | Production runs, VLMs |

**Typical workflow:** local smoke test → HF Jobs with same config → export/quantize ([gguf_conversion.md](gguf_conversion.md))

## Recommended Defaults

| Setting | Value | Notes |
|---------|-------|-------|
| Model size | 0.5B–1.5B first run | Scale up after verifying |
| Max seq length | 512–1024 | Lower = less memory |
| Batch size | 1 | Scale via gradient accumulation |
| Gradient accumulation | 8–16 | Effective batch = 8–16 |
| LoRA rank (r) | 8–16 | alpha = 2×r |
| Dtype | float32 | fp16 causes NaN on MPS; bf16 only on M1 Pro+ and M2/M3/M4 |

### Memory by hardware

| Unified RAM | Max Model Size |
|-------------|---------------|
| 16 GB | ~0.5B–1.5B |
| 32 GB | ~1.5B–3B |
| 64 GB | ~3B (short context) |
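
The table's limits follow from simple arithmetic: in float32, the weights alone cost about 4 bytes per parameter, before activations, gradients, and the (small) LoRA optimizer state. A rough sketch of the estimate (the helper name is illustrative):

```python
def fp32_weight_gb(params_billions: float) -> float:
    """Approximate float32 weight memory in GB (4 bytes per parameter)."""
    return params_billions * 1e9 * 4 / 1024**3

# A 1.5B model needs ~5.6 GB for weights alone, which is why 16 GB of
# unified RAM tops out around the 0.5B-1.5B range once activations and
# the OS share are accounted for.
print(round(fp32_weight_gb(1.5), 1))
```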

## Setup

```bash
xcode-select --install
python3 -m venv .venv && source .venv/bin/activate
pip install -U "torch>=2.2" "transformers>=4.40" "trl>=0.12" "peft>=0.10" \
    datasets accelerate safetensors huggingface_hub
```

Verify MPS:

```bash
python -c "import torch; print(torch.__version__, '| MPS:', torch.backends.mps.is_available())"
```

Optional — configure Accelerate for local Mac (no distributed, no mixed precision, MPS device):

```bash
accelerate config
```
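
For a non-interactive setup, those answers correspond roughly to a config file like the following (path and exact keys can vary by Accelerate version; treat this as a sketch, not canonical):

```yaml
# ~/.cache/huggingface/accelerate/default_config.yaml (sketch)
compute_environment: LOCAL_MACHINE
distributed_type: "NO"
mixed_precision: "no"
use_cpu: false
```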

## Training Script

<details>
<summary><strong>train_lora_sft.py</strong></summary>

```python
import os
from dataclasses import dataclass
from typing import Optional
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

set_seed(42)

@dataclass
class Cfg:
    model_id: str = os.environ.get("MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
    dataset_id: str = os.environ.get("DATASET_ID", "HuggingFaceH4/ultrachat_200k")
    dataset_split: str = os.environ.get("DATASET_SPLIT", "train_sft[:500]")
    data_files: Optional[str] = os.environ.get("DATA_FILES", None)
    text_field: str = os.environ.get("TEXT_FIELD", "")
    messages_field: str = os.environ.get("MESSAGES_FIELD", "messages")
    out_dir: str = os.environ.get("OUT_DIR", "outputs/local-lora")
    max_seq_length: int = int(os.environ.get("MAX_SEQ_LENGTH", "512"))
    max_steps: int = int(os.environ.get("MAX_STEPS", "-1"))

cfg = Cfg()
device = "mps" if torch.backends.mps.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(cfg.model_id, torch_dtype=torch.float32)
model.to(device)
model.config.use_cache = False

if cfg.data_files:
    ds = load_dataset("json", data_files=cfg.data_files, split="train")
else:
    ds = load_dataset(cfg.dataset_id, split=cfg.dataset_split)

def format_example(ex):
    if cfg.text_field and isinstance(ex.get(cfg.text_field), str):
        ex["text"] = ex[cfg.text_field]
        return ex
    msgs = ex.get(cfg.messages_field)
    if isinstance(msgs, list):
        if hasattr(tokenizer, "apply_chat_template"):
            try:
                ex["text"] = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
                return ex
            except Exception:
                pass
        ex["text"] = "\n".join([str(m) for m in msgs])
        return ex
    ex["text"] = str(ex)
    return ex

ds = ds.map(format_example)
ds = ds.remove_columns([c for c in ds.column_names if c != "text"])

lora = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
                  task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])

sft_kwargs = dict(
    output_dir=cfg.out_dir, per_device_train_batch_size=1, gradient_accumulation_steps=8,
    learning_rate=2e-4, logging_steps=10, save_steps=200, save_total_limit=2,
    gradient_checkpointing=True, report_to="none", fp16=False, bf16=False,
    max_seq_length=cfg.max_seq_length, dataset_text_field="text",
)
if cfg.max_steps > 0:
    sft_kwargs["max_steps"] = cfg.max_steps
else:
    sft_kwargs["num_train_epochs"] = 1

trainer = SFTTrainer(model=model, train_dataset=ds, peft_config=lora,
                     args=SFTConfig(**sft_kwargs), processing_class=tokenizer)
trainer.train()
trainer.save_model(cfg.out_dir)
print(f"✅ Saved to: {cfg.out_dir}")
```

</details>

### Run

```bash
python train_lora_sft.py
```

**Env overrides:**

```bash
MODEL_ID="Qwen/Qwen2.5-1.5B-Instruct" python train_lora_sft.py   # different model
MAX_STEPS=50 python train_lora_sft.py                            # quick 50-step test
DATA_FILES="my_data.jsonl" python train_lora_sft.py              # local JSONL file
PYTORCH_ENABLE_MPS_FALLBACK=1 python train_lora_sft.py           # MPS op fallback to CPU
PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 python train_lora_sft.py    # disable MPS memory limit (use with caution)
```

**Local JSONL format** — chat messages or plain text:

```jsonl
{"messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}]}
```

```jsonl
{"text": "User: Hello\nAssistant: Hi!"}
```

For plain text: `DATA_FILES="file.jsonl" TEXT_FIELD="text" MESSAGES_FIELD="" python train_lora_sft.py`
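
A quick way to produce a smoke-test file in the chat-messages shape (filename and contents are illustrative):

```python
import json

records = [
    {"messages": [{"role": "user", "content": "Hello"},
                  {"role": "assistant", "content": "Hi!"}]},
    {"messages": [{"role": "user", "content": "What is LoRA?"},
                  {"role": "assistant", "content": "A low-rank adapter method."}]},
]

# One JSON object per line - the format load_dataset("json", ...) expects
with open("my_data.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")
```

Then train on it with `DATA_FILES="my_data.jsonl" python train_lora_sft.py`.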

### Verify Success

- Loss decreases over steps
- `outputs/local-lora/` contains `adapter_config.json` + `*.safetensors`

## Quick Evaluation

<details>
<summary><strong>eval_generate.py</strong></summary>

```python
import os, torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

BASE = os.environ.get("MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
ADAPTER = os.environ.get("ADAPTER_DIR", "outputs/local-lora")
device = "mps" if torch.backends.mps.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(BASE, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.float32)
model.to(device)
model = PeftModel.from_pretrained(model, ADAPTER)

prompt = os.environ.get("PROMPT", "Explain gradient accumulation in 3 bullet points.")
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=120, do_sample=True, temperature=0.7, top_p=0.9)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

</details>

## Troubleshooting (macOS-Specific)

For general training issues, see [troubleshooting.md](troubleshooting.md).

| Problem | Fix |
|---------|-----|
| MPS unsupported op / crash | `PYTORCH_ENABLE_MPS_FALLBACK=1` |
| OOM / system instability | Reduce `MAX_SEQ_LENGTH`, use smaller model, set `PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0` (caution) |
| fp16 NaN / loss explosion | Keep `fp16=False` (default), lower learning rate |
| LoRA "module not found" | Print `model.named_modules()` to find correct target names |
| TRL TypeError on args | Check TRL version; script uses `SFTConfig` + `processing_class` (TRL ≥0.12) |
| Intel Mac | No MPS — use HF Jobs instead |

**Common LoRA target modules by architecture:**

| Architecture | target_modules |
|-------------|---------------|
| Llama/Qwen/Mistral | `q_proj`, `k_proj`, `v_proj`, `o_proj` |
| GPT-2/GPT-J | `c_attn`, `c_proj` |
| BLOOM | `query_key_value`, `dense` |

## MLX Alternative

[MLX](https://github.com/ml-explore/mlx) offers tighter Apple Silicon integration but has a smaller ecosystem and less mature training APIs. For this skill's workflow (local validation → HF Jobs), PyTorch + MPS is recommended for consistency. See [mlx-lm](https://github.com/ml-explore/mlx-lm) for MLX-based fine-tuning.

## See Also

- [troubleshooting.md](troubleshooting.md) — General TRL troubleshooting
- [hardware_guide.md](hardware_guide.md) — GPU selection for HF Jobs
- [gguf_conversion.md](gguf_conversion.md) — Export for on-device inference
- [training_methods.md](training_methods.md) — SFT, DPO, GRPO overview

@@ -0,0 +1,371 @@

# Reliability Principles for Training Jobs

These principles are derived from real production failures and successful fixes. Following them prevents common failure modes and ensures reliable job execution.

## Principle 1: Always Verify Before Use

**Rule:** Never assume repos, datasets, or resources exist. Verify with tools first.

### What It Prevents

- **Non-existent datasets** - Jobs fail immediately when dataset doesn't exist
- **Typos in names** - Simple mistakes like "argilla-dpo-mix-7k" vs "ultrafeedback_binarized"
- **Incorrect paths** - Old or moved repos, renamed files
- **Missing dependencies** - Undocumented requirements

### How to Apply

**Before submitting ANY job:**

```python
# Verify dataset exists
dataset_search({"query": "dataset-name", "author": "author-name", "limit": 5})
hub_repo_details(["author/dataset-name"], repo_type="dataset")

# Verify model exists
hub_repo_details(["org/model-name"], repo_type="model")

# Check script/file paths (for URL-based scripts)
# Verify before using: https://github.com/user/repo/blob/main/script.py
```

**Examples that would have caught errors:**

```python
# ❌ WRONG: Assumed dataset exists
hf_jobs("uv", {
    "script": """...""",
    "env": {"DATASET": "trl-lib/argilla-dpo-mix-7k"}  # Doesn't exist!
})

# ✅ CORRECT: Verify first
dataset_search({"query": "argilla dpo", "author": "trl-lib"})
# Would show: "trl-lib/ultrafeedback_binarized" is the correct name

hub_repo_details(["trl-lib/ultrafeedback_binarized"], repo_type="dataset")
# Confirms it exists before using
```

### Implementation Checklist

- [ ] Check dataset exists before training
- [ ] Verify base model exists before fine-tuning
- [ ] Confirm adapter model exists before GGUF conversion
- [ ] Test script URLs are valid before submitting
- [ ] Validate file paths in repositories
- [ ] Check for recent updates/renames of resources

**Time cost:** 5-10 seconds
**Time saved:** Hours of failed job time + debugging
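
The checklist can also be wrapped in a tiny fail-fast helper inside the script itself. A sketch (`verify_job_inputs` is illustrative; in practice the `exists` callable could be backed by a Hub client such as `huggingface_hub.repo_exists`, assuming that API is available in your version):

```python
def verify_job_inputs(exists, model: str, dataset: str) -> None:
    """Raise before any expensive work if a required repo is missing.

    `exists(repo_id, repo_type=...) -> bool` is injected so it can be
    backed by a real Hub client, or by a stub when testing.
    """
    missing = [repo for repo, kind in [(model, "model"), (dataset, "dataset")]
               if not exists(repo, repo_type=kind)]
    if missing:
        raise ValueError(f"Repos not found, fix before submitting: {missing}")
```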

---

## Principle 2: Prioritize Reliability Over Performance

**Rule:** Default to what is most likely to succeed, not what is theoretically fastest.

### What It Prevents

- **Hardware incompatibilities** - Features that fail on certain GPUs
- **Unstable optimizations** - Speed-ups that cause crashes
- **Complex configurations** - More failure points
- **Build system issues** - Unreliable compilation methods

### How to Apply

**Choose reliability:**

```python
# ❌ RISKY: Aggressive optimization that may fail
SFTConfig(
    torch_compile=True,       # Can fail on T4, A10G GPUs
    optim="adamw_bnb_8bit",   # Requires specific setup
    fp16=False,               # May cause training instability
    ...
)

# ✅ SAFE: Proven defaults
SFTConfig(
    # torch_compile=True,     # Commented with note: "Enable on H100 for 20% speedup"
    optim="adamw_torch",      # Standard, always works
    fp16=True,                # Stable and fast
    ...
)
```

**For build processes:**

```python
# ❌ UNRELIABLE: Uses make (platform-dependent)
subprocess.run(["make", "-C", "/tmp/llama.cpp", "llama-quantize"], check=True)

# ✅ RELIABLE: Uses CMake (consistent, documented)
subprocess.run([
    "cmake", "-B", "/tmp/llama.cpp/build", "-S", "/tmp/llama.cpp",
    "-DGGML_CUDA=OFF"  # Disable CUDA for faster, more reliable build
], check=True)

subprocess.run([
    "cmake", "--build", "/tmp/llama.cpp/build",
    "--target", "llama-quantize", "-j", "4"
], check=True)
```

### Real-World Example

**The `torch.compile` failure:**
- Added for "20% speedup" on H100
- **Failed fatally on T4-medium** with cryptic error
- Misdiagnosed as dataset issue (cost hours)
- **Fix:** Disable by default, add as optional comment

**Result:** Reliability > 20% performance gain

### Implementation Checklist

- [ ] Use proven, standard configurations by default
- [ ] Comment out performance optimizations with hardware notes
- [ ] Use stable build systems (CMake > make)
- [ ] Test on target hardware before production
- [ ] Document known incompatibilities
- [ ] Provide "safe" and "fast" variants when needed

**Performance loss:** 10-20% in best case
**Reliability gain:** 95%+ success rate vs 60-70%

---

## Principle 3: Create Atomic, Self-Contained Scripts

**Rule:** Scripts should work as complete, independent units. Don't remove parts to "simplify."

### What It Prevents

- **Missing dependencies** - Removed "unnecessary" packages that are actually required
- **Incomplete processes** - Skipped steps that seem redundant
- **Environment assumptions** - Scripts that need pre-setup
- **Partial failures** - Some parts work, others fail silently

### How to Apply

**Complete dependency specifications:**

```python
# ❌ INCOMPLETE: "Simplified" by removing dependencies
# /// script
# dependencies = [
#   "transformers",
#   "peft",
#   "torch",
# ]
# ///

# ✅ COMPLETE: All dependencies explicit
# /// script
# dependencies = [
#   "transformers>=4.36.0",
#   "peft>=0.7.0",
#   "torch>=2.0.0",
#   "accelerate>=0.24.0",
#   "huggingface_hub>=0.20.0",
#   "sentencepiece>=0.1.99",  # Required for tokenizers
#   "protobuf>=3.20.0",       # Required for tokenizers
#   "numpy",
#   "gguf",
# ]
# ///
```

**Complete build processes:**

```python
# ❌ INCOMPLETE: Assumes build tools exist
subprocess.run(["git", "clone", "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"])
subprocess.run(["make", "-C", "/tmp/llama.cpp", "llama-quantize"])  # FAILS: no gcc/make

# ✅ COMPLETE: Installs all requirements
subprocess.run(["apt-get", "update", "-qq"], check=True)
subprocess.run(["apt-get", "install", "-y", "-qq", "build-essential", "cmake"], check=True)
subprocess.run(["git", "clone", "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"])
# ... then build
```

### Real-World Example

**The `sentencepiece` failure:**
- Original script had it: worked fine
- "Simplified" version removed it: "doesn't look necessary"
- **GGUF conversion failed silently** - tokenizer couldn't convert
- Hard to debug: no obvious error message
- **Fix:** Restore all original dependencies

**Result:** Don't remove dependencies without thorough testing

### Implementation Checklist

- [ ] All dependencies in PEP 723 header with version pins
- [ ] All system packages installed by script
- [ ] No assumptions about pre-existing environment
- [ ] No "optional" steps that are actually required
- [ ] Test scripts in clean environment
- [ ] Document why each dependency is needed

**Complexity:** Slightly longer scripts
**Reliability:** Scripts "just work" every time

---

## Principle 4: Provide Clear Error Context

**Rule:** When things fail, make it obvious what went wrong and how to fix it.

### How to Apply

**Wrap subprocess calls:**

```python
# ❌ UNCLEAR: Silent failure
subprocess.run([...], check=True, capture_output=True)

# ✅ CLEAR: Shows what failed
try:
    result = subprocess.run(
        [...],
        check=True,
        capture_output=True,
        text=True
    )
    print(result.stdout)
    if result.stderr:
        print("Warnings:", result.stderr)
except subprocess.CalledProcessError as e:
    print(f"❌ Command failed!")
    print("STDOUT:", e.stdout)
    print("STDERR:", e.stderr)
    raise
```

**Validate inputs:**

```python
# ❌ UNCLEAR: Fails later with cryptic error
model = load_model(MODEL_NAME)

# ✅ CLEAR: Fails fast with clear message
if not MODEL_NAME:
    raise ValueError("MODEL_NAME environment variable not set!")

print(f"Loading model: {MODEL_NAME}")
try:
    model = load_model(MODEL_NAME)
    print(f"✅ Model loaded successfully")
except Exception as e:
    print(f"❌ Failed to load model: {MODEL_NAME}")
    print(f"Error: {e}")
    print("Hint: Check that model exists on Hub")
    raise
```

### Implementation Checklist

- [ ] Wrap external calls with try/except
- [ ] Print stdout/stderr on failure
- [ ] Validate environment variables early
- [ ] Add progress indicators (✅, ❌, 🔄)
- [ ] Include hints for common failures
- [ ] Log configuration at start

---

## Principle 5: Test the Happy Path on Known-Good Inputs

**Rule:** Before using new code in production, test with inputs you know work.

### How to Apply

**Known-good test inputs:**

```python
# For training
TEST_DATASET = "trl-lib/Capybara"  # Small, well-formatted, widely used
TEST_MODEL = "Qwen/Qwen2.5-0.5B"   # Small, fast, reliable

# For GGUF conversion
TEST_ADAPTER = "evalstate/qwen-capybara-medium"  # Known working model
TEST_BASE = "Qwen/Qwen2.5-0.5B"                  # Compatible base
```

**Testing workflow:**

1. Test with known-good inputs first
2. If that works, try production inputs
3. If production fails, you know it's the inputs (not code)
4. Isolate the difference

### Implementation Checklist

- [ ] Maintain list of known-good test models/datasets
- [ ] Test new scripts with test inputs first
- [ ] Document what makes inputs "good"
- [ ] Keep test jobs cheap (small models, short timeouts)
- [ ] Only move to production after test succeeds

**Time cost:** 5-10 minutes for test run
**Debugging time saved:** Hours

---

## Summary: The Reliability Checklist

Before submitting ANY job:

### Pre-Flight Checks
- [ ] **Verified** all repos/datasets exist (hub_repo_details)
- [ ] **Tested** with known-good inputs if new code
- [ ] **Using** proven hardware/configuration
- [ ] **Included** all dependencies in PEP 723 header
- [ ] **Installed** system requirements (build tools, etc.)
- [ ] **Set** appropriate timeout (not default 30m)
- [ ] **Configured** Hub push with HF_TOKEN
- [ ] **Added** clear error handling

### Script Quality
- [ ] Self-contained (no external setup needed)
- [ ] Complete dependencies listed
- [ ] Build tools installed by script
- [ ] Progress indicators included
- [ ] Error messages are clear
- [ ] Configuration logged at start

### Job Configuration
- [ ] Timeout > expected runtime + 30% buffer
- [ ] Hardware appropriate for model size
- [ ] Secrets include HF_TOKEN
- [ ] Environment variables set correctly
- [ ] Cost estimated and acceptable
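
The timeout rule above can be computed rather than guessed. A sketch (the `"<minutes>m"` string follows the timeout format used by the `hf_jobs` examples in this guide):

```python
def job_timeout(expected_minutes: int, buffer: float = 0.3) -> str:
    """Expected runtime plus a safety buffer, as an hf_jobs-style timeout string."""
    return f"{int(expected_minutes * (1 + buffer))}m"

# A run expected to take 4 hours gets a 312-minute (~5.2 h) timeout
print(job_timeout(240))
```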

**Following these principles transforms job success rate from ~60-70% to ~95%+**

---

## When Principles Conflict

Sometimes reliability and performance conflict. Here's how to choose:

| Scenario | Choose | Rationale |
|----------|--------|-----------|
| Demo/test | Reliability | Fast failure is worse than slow success |
| Production (first run) | Reliability | Prove it works before optimizing |
| Production (proven) | Performance | Safe to optimize after validation |
| Time-critical | Reliability | Failures cause more delay than slow runs |
| Cost-critical | Balanced | Test with small model, then optimize |

**General rule:** Reliability first, optimize second.

---

## Further Reading

- `troubleshooting.md` - Common issues and fixes
- `training_patterns.md` - Proven training configurations
- `gguf_conversion.md` - Production GGUF workflow

@@ -0,0 +1,189 @@

# Trackio Integration for TRL Training

**Trackio** is an experiment tracking library that provides real-time metrics visualization for remote training on Hugging Face Jobs infrastructure.

⚠️ **IMPORTANT**: For Jobs training (remote cloud GPUs):
- Training happens on ephemeral cloud runners (not your local machine)
- Trackio syncs metrics to a Hugging Face Space for real-time monitoring
- Without a Space, metrics are lost when the job completes
- The Space dashboard persists your training metrics permanently

## Setting Up Trackio for Jobs

**Step 1: Add trackio dependency**

```python
# /// script
# dependencies = [
#   "trl>=0.12.0",
#   "trackio",  # Required!
# ]
# ///
```

**Step 2: Create a Trackio Space (one-time setup)**

**Option A: Let Trackio auto-create (Recommended)**
Pass a `space_id` to `trackio.init()` and Trackio will automatically create the Space if it doesn't exist.

**Option B: Create manually**
- Create Space via Hub UI at https://huggingface.co/new-space
- Select Gradio SDK
- OR use command: `hf repos create my-trackio-dashboard --type space --space-sdk gradio`

**Step 3: Initialize Trackio with space_id**

```python
import trackio

trackio.init(
    project="my-training",
    space_id="username/trackio",  # CRITICAL for Jobs! Replace 'username' with your HF username
    config={
        "model": "Qwen/Qwen2.5-0.5B",
        "dataset": "trl-lib/Capybara",
        "learning_rate": 2e-5,
    }
)
```

**Step 4: Configure TRL to use Trackio**

```python
SFTConfig(
    report_to="trackio",
    # ... other config
)
```

**Step 5: Finish tracking**

```python
trainer.train()
trackio.finish()  # Ensures final metrics are synced
```

## What Trackio Tracks

Trackio automatically logs:
- ✅ Training loss
- ✅ Learning rate
- ✅ GPU utilization
- ✅ Memory usage
- ✅ Training throughput
- ✅ Custom metrics

## How It Works with Jobs

1. **Training runs** → Metrics logged to local SQLite DB
2. **Every 5 minutes** → Trackio syncs DB to HF Dataset (Parquet)
3. **Space dashboard** → Reads from Dataset, displays metrics in real-time
4. **Job completes** → Final sync ensures all metrics persisted

## Default Configuration Pattern

**Use sensible defaults for trackio configuration unless the user requests otherwise.**

### Recommended Defaults

```python
import trackio

trackio.init(
    project="qwen-capybara-sft",
    name="baseline-run",           # Descriptive name user will recognize
    space_id="username/trackio",   # Default space: {username}/trackio
    config={
        # Keep config minimal - hyperparameters and model/dataset info only
        "model": "Qwen/Qwen2.5-0.5B",
        "dataset": "trl-lib/Capybara",
        "learning_rate": 2e-5,
        "num_epochs": 3,
    }
)
```

**Key principles:**
- **Space ID**: Use `{username}/trackio` with "trackio" as default space name
- **Run naming**: Unless otherwise specified, name the run in a way the user will recognize
- **Config**: Keep minimal - don't automatically capture job metadata unless requested
- **Grouping**: Optional - only use if user requests organizing related experiments

## Grouping Runs (Optional)

The `group` parameter helps organize related runs together in the dashboard sidebar. This is useful when the user is running multiple experiments with different configurations but wants to compare them together:

```python
# Example: Group runs by experiment type
trackio.init(project="my-project", run_name="baseline-run-1", group="baseline")
trackio.init(project="my-project", run_name="augmented-run-1", group="augmented")
trackio.init(project="my-project", run_name="tuned-run-1", group="tuned")
```

Runs with the same group name can be grouped together in the sidebar, making it easier to compare related experiments. You can group by any configuration parameter:

```python
# Hyperparameter sweep - group by learning rate
trackio.init(project="hyperparam-sweep", run_name="lr-0.001-run", group="lr_0.001")
trackio.init(project="hyperparam-sweep", run_name="lr-0.01-run", group="lr_0.01")
```

## Environment Variables for Jobs

You can configure trackio using environment variables instead of passing parameters to `trackio.init()`. This is useful for managing configuration across multiple jobs.

**`HF_TOKEN`**
Required for creating Spaces and writing to datasets (passed via `secrets`):
```python
hf_jobs("uv", {
    "script": "...",
    "secrets": {
        "HF_TOKEN": "$HF_TOKEN"  # Enables Space creation and Hub push
    }
})
```

### Example with Environment Variables

```python
hf_jobs("uv", {
    "script": """
# Training script - trackio config from environment
import trackio
from datetime import datetime

# Auto-generate run name
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M")
run_name = f"sft_qwen25_{timestamp}"

# Project and space_id can come from environment variables
trackio.init(run_name=run_name, group="SFT")

# ... training code ...
trackio.finish()
    """,
    "flavor": "a10g-large",
    "timeout": "2h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

**When to use environment variables:**
- Managing multiple jobs with same configuration
- Keeping training scripts portable across projects
- Separating configuration from code

**When to use direct parameters:**
- Single job with specific configuration
- When clarity in code is preferred
- When each job has different project/space

## Viewing the Dashboard

After starting training:
1. Navigate to the Space: `https://huggingface.co/spaces/username/trackio`
2. The Gradio dashboard shows all tracked experiments
3. Filter by project, compare runs, view charts with smoothing

## Recommendation

- **Trackio**: Best for real-time monitoring during long training runs
- **Weights & Biases**: Best for team collaboration, requires account
|
||||
@@ -0,0 +1,150 @@
# TRL Training Methods Overview

TRL (Transformer Reinforcement Learning) provides multiple training methods for fine-tuning and aligning language models. This reference gives a brief overview of each method.

## Supervised Fine-Tuning (SFT)

**What it is:** Standard instruction tuning with supervised learning on demonstration data.

**When to use:**
- Initial fine-tuning of base models on task-specific data
- Teaching new capabilities or domains
- The most common starting point for fine-tuning

**Dataset format:** Conversational format with a "messages" field, OR a "text" field, OR "prompt"/"completion" pairs

**Example:**
```python
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="my-model",
        push_to_hub=True,
        hub_model_id="username/my-model",
        eval_strategy="no",  # Disable eval for simple example
        # max_length=1024 is the default - only set if you need a different length
    )
)
trainer.train()
```

**Note:** For production training with evaluation monitoring, see `scripts/train_sft_example.py`

**Documentation:** `hf_doc_fetch("https://huggingface.co/docs/trl/sft_trainer")`

## Direct Preference Optimization (DPO)

**What it is:** Alignment method that trains directly on preference pairs (chosen vs rejected responses) without requiring a reward model.

**When to use:**
- Aligning models to human preferences
- Improving response quality after SFT
- You have paired preference data (chosen/rejected responses)

**Dataset format:** Preference pairs with "chosen" and "rejected" fields

**Example:**
```python
from trl import DPOTrainer, DPOConfig

trainer = DPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # Use instruct model
    train_dataset=dataset,
    args=DPOConfig(
        output_dir="dpo-model",
        beta=0.1,  # KL penalty coefficient
        eval_strategy="no",  # Disable eval for simple example
        # max_length=1024 is the default - only set if you need a different length
    )
)
trainer.train()
```

**Note:** For production training with evaluation monitoring, see `scripts/train_dpo_example.py`

**Documentation:** `hf_doc_fetch("https://huggingface.co/docs/trl/dpo_trainer")`

## Group Relative Policy Optimization (GRPO)

**What it is:** Online RL method that optimizes relative to group performance, useful for tasks with verifiable rewards.

**When to use:**
- Tasks with automatic reward signals (code execution, math verification)
- Online learning scenarios
- When DPO offline data is insufficient

**Dataset format:** Prompt-only format (model generates responses, reward computed online)

**Example:**
```python
# Use TRL maintained script
hf_jobs("uv", {
    "script": "https://raw.githubusercontent.com/huggingface/trl/main/examples/scripts/grpo.py",
    "script_args": [
        "--model_name_or_path", "Qwen/Qwen2.5-0.5B-Instruct",
        "--dataset_name", "trl-lib/math_shepherd",
        "--output_dir", "grpo-model"
    ],
    "flavor": "a10g-large",
    "timeout": "4h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

**Documentation:** `hf_doc_fetch("https://huggingface.co/docs/trl/grpo_trainer")`

## Reward Modeling

**What it is:** Train a reward model to score responses, used as a component in RLHF pipelines.

**When to use:**
- Building an RLHF pipeline
- Need automatic quality scoring
- Creating reward signals for PPO training

**Dataset format:** Preference pairs with "chosen" and "rejected" responses

**Documentation:** `hf_doc_fetch("https://huggingface.co/docs/trl/reward_trainer")`

## Method Selection Guide

| Method | Complexity | Data Required | Use Case |
|--------|-----------|---------------|----------|
| **SFT** | Low | Demonstrations | Initial fine-tuning |
| **DPO** | Medium | Paired preferences | Post-SFT alignment |
| **GRPO** | Medium | Prompts + reward fn | Online RL with automatic rewards |
| **Reward** | Medium | Paired preferences | Building RLHF pipeline |

## Recommended Pipeline

**For most use cases:**
1. **Start with SFT** - Fine-tune base model on task data
2. **Follow with DPO** - Align to preferences using paired data
3. **Optional: GGUF conversion** - Deploy for local inference

**For advanced RL scenarios:**
1. **Start with SFT** - Fine-tune base model
2. **Train reward model** - On preference data
3. **Run online RL (e.g., PPO)** - Use the reward model to supply reward signals

## Dataset Format Reference

For complete dataset format specifications, use:
```python
hf_doc_fetch("https://huggingface.co/docs/trl/dataset_formats")
```

Or validate your dataset:
```bash
uv run https://huggingface.co/datasets/mcp-tools/skills/raw/main/dataset_inspector.py \
  --dataset your/dataset --split train
```

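Alongside the inspector, a quick in-process sanity check can catch missing fields before a job is submitted. This is a hypothetical helper mirroring the field requirements described in this reference (SFT: "messages", "text", or "prompt"/"completion"; DPO: "chosen"/"rejected"; GRPO: prompt-only):

```python
REQUIRED_FIELDS = {
    # Any one of these field sets makes a record valid for the method
    "sft": [{"messages"}, {"text"}, {"prompt", "completion"}],
    "dpo": [{"chosen", "rejected"}],
    "grpo": [{"prompt"}],
}

def record_matches_method(record: dict, method: str) -> bool:
    """Return True if the record carries one accepted field set for the method."""
    keys = set(record)
    return any(fields <= keys for fields in REQUIRED_FIELDS[method])

record_matches_method({"messages": []}, "sft")                  # conversational SFT
record_matches_method({"chosen": "a", "rejected": "b"}, "dpo")  # preference pair
```
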
## See Also

- `references/training_patterns.md` - Common training patterns and examples
- `scripts/train_sft_example.py` - Complete SFT template
- `scripts/train_dpo_example.py` - Complete DPO template
- [Dataset Inspector](https://huggingface.co/datasets/mcp-tools/skills/raw/main/dataset_inspector.py) - Dataset format validation tool
@@ -0,0 +1,203 @@
# Common Training Patterns

This guide provides common training patterns and use cases for TRL on Hugging Face Jobs.

## Multi-GPU Training

Automatic distributed training across multiple GPUs. TRL/Accelerate handles distribution automatically:

```python
hf_jobs("uv", {
    "script": """
# Your training script here (same as single GPU)
# No changes needed - Accelerate detects multiple GPUs
""",
    "flavor": "a10g-largex2",  # 2x A10G GPUs
    "timeout": "4h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

**Tips for multi-GPU:**
- No code changes needed
- Use `per_device_train_batch_size` (per GPU, not total)
- Effective batch size = `per_device_train_batch_size` × `num_gpus` × `gradient_accumulation_steps`
- Monitor GPU utilization to ensure both GPUs are being used

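The effective batch size formula from the tips above is plain arithmetic; a small sketch for sizing gradient accumulation toward a target batch:

```python
def effective_batch_size(per_device: int, num_gpus: int, grad_accum: int) -> int:
    """Effective batch = per-device batch x GPU count x gradient accumulation steps."""
    return per_device * num_gpus * grad_accum

def accum_steps_for(target: int, per_device: int, num_gpus: int) -> int:
    """Gradient accumulation steps needed to reach (at least) a target batch."""
    per_step = per_device * num_gpus
    return -(-target // per_step)  # ceiling division

effective_batch_size(per_device=4, num_gpus=2, grad_accum=8)  # 64
accum_steps_for(target=64, per_device=4, num_gpus=2)          # 8
```
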
## DPO Training (Preference Learning)

Train with preference data for alignment:

```python
hf_jobs("uv", {
    "script": """
# /// script
# dependencies = ["trl>=0.12.0", "trackio"]
# ///

from datasets import load_dataset
from trl import DPOTrainer, DPOConfig
import trackio

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Create train/eval split
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)

config = DPOConfig(
    output_dir="dpo-model",
    push_to_hub=True,
    hub_model_id="username/dpo-model",
    num_train_epochs=1,
    beta=0.1,  # KL penalty coefficient
    eval_strategy="steps",
    eval_steps=50,
    report_to="trackio",
    run_name="baseline_run",  # use a meaningful run name
    # max_length=1024,  # Default - only set if you need a different sequence length
)

trainer = DPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # Use instruct model as base
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],  # IMPORTANT: Provide eval_dataset when eval_strategy is enabled
    args=config,
)

trainer.train()
trainer.push_to_hub()
trackio.finish()
""",
    "flavor": "a10g-large",
    "timeout": "3h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

**For DPO documentation:** Use `hf_doc_fetch("https://huggingface.co/docs/trl/dpo_trainer")`

## GRPO Training (Online RL)

Group Relative Policy Optimization for online reinforcement learning:

```python
hf_jobs("uv", {
    "script": "https://raw.githubusercontent.com/huggingface/trl/main/examples/scripts/grpo.py",
    "script_args": [
        "--model_name_or_path", "Qwen/Qwen2.5-0.5B-Instruct",
        "--dataset_name", "trl-lib/math_shepherd",
        "--output_dir", "grpo-model",
        "--push_to_hub",
        "--hub_model_id", "username/grpo-model"
    ],
    "flavor": "a10g-large",
    "timeout": "4h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

**For GRPO documentation:** Use `hf_doc_fetch("https://huggingface.co/docs/trl/grpo_trainer")`

## Trackio Configuration

**Use sensible defaults for trackio setup.** See `references/trackio_guide.md` for complete documentation, including grouping runs for experiments.

### Basic Pattern

```python
import trackio

trackio.init(
    project="my-training",
    run_name="baseline-run",  # Descriptive name the user will recognize
    space_id="username/trackio",  # Default space: {username}/trackio
    config={
        # Keep config minimal - hyperparameters and model/dataset info only
        "model": "Qwen/Qwen2.5-0.5B",
        "dataset": "trl-lib/Capybara",
        "learning_rate": 2e-5,
    }
)

# Your training code...

trackio.finish()
```

### Grouping for Experiments (Optional)

When the user wants to compare related runs, use the `group` parameter:

```python
# Hyperparameter sweep
trackio.init(project="hyperparam-sweep", run_name="lr-0.001", group="lr_0.001")
trackio.init(project="hyperparam-sweep", run_name="lr-0.01", group="lr_0.01")
```

## Pattern Selection Guide

| Use Case | Pattern | Hardware | Time |
|----------|---------|----------|------|
| SFT training | `scripts/train_sft_example.py` | a10g-large | 2-6 hours |
| Large dataset (>10K) | Multi-GPU | a10g-largex2 | 4-12 hours |
| Preference learning | DPO Training | a10g-large | 2-4 hours |
| Online RL | GRPO Training | a10g-large | 3-6 hours |

## Critical: Evaluation Dataset Requirements

**⚠️ IMPORTANT**: If you set `eval_strategy="steps"` or `eval_strategy="epoch"`, you **MUST** provide an `eval_dataset` to the trainer, or the training will hang.

### ✅ CORRECT - With eval dataset:
```python
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],  # ← MUST provide when eval_strategy is enabled
    args=SFTConfig(eval_strategy="steps", ...),
)
```

### ❌ WRONG - Will hang:
```python
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    # NO eval_dataset but eval_strategy="steps" ← WILL HANG
    args=SFTConfig(eval_strategy="steps", ...),
)
```

### Option: Disable evaluation if no eval dataset
```python
config = SFTConfig(
    eval_strategy="no",  # ← Explicitly disable evaluation
    # ... other config
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    # No eval_dataset needed
    args=config,
)
```

## Best Practices

1. **Use train/eval splits** - Create an evaluation split for monitoring progress
2. **Enable Trackio** - Monitor progress in real time
3. **Add a 20-30% buffer to the timeout** - Account for loading/saving overhead
4. **Test with TRL official scripts first** - Use maintained examples before custom code
5. **Always provide eval_dataset** - When using eval_strategy, or set it to "no"
6. **Use multi-GPU for large models** - 7B+ models benefit significantly

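Best practice 3 (the 20-30% buffer) is easy to compute rather than eyeball; a tiny sketch that returns padded minutes, which you can round up to an hour value such as `"3h"` for the `timeout` field:

```python
import math

def padded_timeout_minutes(estimated_minutes: float, buffer: float = 0.25) -> int:
    """Estimated runtime plus a safety buffer (default 25%), rounded up."""
    return math.ceil(estimated_minutes * (1 + buffer))

padded_timeout_minutes(120)       # 150 -> round up to "3h"
padded_timeout_minutes(60, 0.3)   # 78
```
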
## See Also

- `scripts/train_sft_example.py` - Complete SFT template with Trackio and eval split
- `scripts/train_dpo_example.py` - Complete DPO template
- `scripts/train_grpo_example.py` - Complete GRPO template
- `references/hardware_guide.md` - Detailed hardware specifications
- `references/training_methods.md` - Overview of all TRL training methods
- `references/troubleshooting.md` - Common issues and solutions
@@ -0,0 +1,282 @@
# Troubleshooting TRL Training Jobs

Common issues and solutions when training with TRL on Hugging Face Jobs.

## Training Hangs at "Starting training..." Step

**Problem:** The job starts but hangs at the training step - it never progresses and never times out; it just sits there.

**Root Cause:** Using `eval_strategy="steps"` or `eval_strategy="epoch"` without providing an `eval_dataset` to the trainer.

**Solution:**

**Option A: Provide eval_dataset (recommended)**
```python
# Create train/eval split
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],  # ← MUST provide when eval_strategy is enabled
    args=SFTConfig(
        eval_strategy="steps",
        eval_steps=50,
        ...
    ),
)
```

**Option B: Disable evaluation**
```python
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    # No eval_dataset
    args=SFTConfig(
        eval_strategy="no",  # ← Explicitly disable
        ...
    ),
)
```

**Prevention:**
- Always create a train/eval split for better monitoring
- Use `dataset.train_test_split(test_size=0.1, seed=42)`
- Check the example scripts: `scripts/train_sft_example.py` includes a proper eval setup

## Job Times Out

**Problem:** The job terminates before training completes, and all progress is lost.

**Solutions:**
- Increase the timeout parameter (e.g., `"timeout": "4h"`)
- Reduce `num_train_epochs` or use a smaller dataset slice
- Use a smaller model or enable LoRA/PEFT to speed up training
- Add a 20-30% buffer to the estimated time for loading/saving overhead

**Prevention:**
- Always start with a quick demo run to estimate timing
- Use `scripts/estimate_cost.py` to get time estimates
- Monitor first runs closely via Trackio or logs

## Model Not Saved to Hub

**Problem:** Training completes but the model doesn't appear on the Hub - all work is lost.

**Check:**
- [ ] `push_to_hub=True` in training config
- [ ] `hub_model_id` specified with username (e.g., `"username/model-name"`)
- [ ] `secrets={"HF_TOKEN": "$HF_TOKEN"}` in job submission
- [ ] User has write access to target repo
- [ ] Token has write permissions (check at https://huggingface.co/settings/tokens)
- [ ] Training script calls `trainer.push_to_hub()` at the end

**See:** `references/hub_saving.md` for detailed Hub authentication troubleshooting

## Out of Memory (OOM)

**Problem:** The job fails with a CUDA out-of-memory error.

**Solutions (in order of preference):**
1. **Reduce batch size:** Lower `per_device_train_batch_size` (try 4 → 2 → 1)
2. **Increase gradient accumulation:** Raise `gradient_accumulation_steps` to maintain the effective batch size
3. **Disable evaluation:** Remove `eval_dataset` and `eval_strategy` (saves ~40% memory, good for demos)
4. **Enable LoRA/PEFT:** Use `peft_config=LoraConfig(r=8, lora_alpha=16)` to train adapters only (smaller rank = less memory)
5. **Use a larger GPU:** Switch from `t4-small` → `l4x1` → `a10g-large` → `a100-large`
6. **Enable gradient checkpointing:** Set `gradient_checkpointing=True` in config (slower but saves memory)
7. **Use a smaller model:** Try a smaller variant (e.g., 0.5B instead of 3B)

**Memory guidelines:**
- T4 (16GB): <1B models with LoRA
- A10G (24GB): 1-3B models with LoRA, <1B full fine-tune
- A100 (40GB/80GB): 7B+ models with LoRA, 3B full fine-tune

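The memory guidelines above can be folded into a rough chooser. This is a heuristic sketch only (real requirements vary with sequence length, batch size, and optimizer state), using the flavor names that appear in this guide:

```python
def suggest_flavor(model_params_b: float, lora: bool = True) -> str:
    """Rough GPU suggestion from the VRAM guidelines above (heuristic only)."""
    if lora:
        if model_params_b < 1:
            return "t4-small"    # T4 16GB: <1B with LoRA
        if model_params_b <= 3:
            return "a10g-large"  # A10G 24GB: 1-3B with LoRA
        return "a100-large"      # A100 40/80GB: 7B+ with LoRA
    # Full fine-tunes need much more headroom
    if model_params_b < 1:
        return "a10g-large"      # A10G: <1B full fine-tune
    return "a100-large"          # A100: up to ~3B full fine-tune; larger needs multi-GPU

suggest_flavor(0.5)              # "t4-small"
suggest_flavor(7)                # "a100-large"
```
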
## Parameter Naming Issues

**Problem:** `TypeError: SFTConfig.__init__() got an unexpected keyword argument 'max_seq_length'`

**Cause:** TRL config classes use `max_length`, not `max_seq_length`.

**Solution:**
```python
# ✅ CORRECT - TRL uses max_length
SFTConfig(max_length=512)
DPOConfig(max_length=512)

# ❌ WRONG - This will fail
SFTConfig(max_seq_length=512)
```

**Note:** Most TRL configs don't require an explicit `max_length` - the default (1024) works well. Only set it if you need a specific value.

## Dataset Format Error

**Problem:** Training fails with dataset format errors or missing fields.

**Solutions:**
1. **Check the format documentation:**
   ```python
   hf_doc_fetch("https://huggingface.co/docs/trl/dataset_formats")
   ```

2. **Validate the dataset before training:**
   ```bash
   uv run https://huggingface.co/datasets/mcp-tools/skills/raw/main/dataset_inspector.py \
     --dataset <dataset-name> --split train
   ```
   Or via hf_jobs:
   ```python
   hf_jobs("uv", {
       "script": "https://huggingface.co/datasets/mcp-tools/skills/raw/main/dataset_inspector.py",
       "script_args": ["--dataset", "dataset-name", "--split", "train"]
   })
   ```

3. **Verify field names:**
   - **SFT:** Needs a "messages" field (conversational), OR a "text" field, OR "prompt"/"completion"
   - **DPO:** Needs "chosen" and "rejected" fields
   - **GRPO:** Needs prompt-only format

4. **Check the dataset split:**
   - Ensure the split exists (e.g., `split="train"`)
   - Preview the dataset: `load_dataset("name", split="train[:5]")`

## Import/Module Errors

**Problem:** The job fails with "ModuleNotFoundError" or other import errors.

**Solutions:**
1. **Add a PEP 723 header with dependencies:**
   ```python
   # /// script
   # dependencies = [
   #     "trl>=0.12.0",
   #     "peft>=0.7.0",
   #     "transformers>=4.36.0",
   # ]
   # ///
   ```

2. **Verify the exact format:**
   - Must have `# ///` delimiters (with a space after `#`)
   - Dependencies must be valid PyPI package names
   - Check spelling and version constraints

3. **Test locally first:**
   ```bash
   uv run train.py  # Tests if dependencies are correct
   ```

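The format rules in step 2 can also be checked mechanically before submission; a regex sketch based on the `# /// script` ... `# ///` block shape (an illustration, not an official PEP 723 parser):

```python
import re

# A PEP 723 block: "# /// script", then comment-only metadata lines, then "# ///"
PEP723_RE = re.compile(
    r"^# /// script\s*$(?:\r?\n#[^\r\n]*)*\r?\n# ///\s*$",
    re.MULTILINE,
)

def has_pep723_header(script_text: str) -> bool:
    """True if the script contains a well-formed PEP 723 metadata block."""
    return bool(PEP723_RE.search(script_text))

good = '# /// script\n# dependencies = ["trl"]\n# ///\nprint("hi")\n'
bad = '#/// script\n# dependencies = []\n# ///\n'  # missing space after '#'
```
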
## Authentication Errors

**Problem:** The job fails with authentication or permission errors when pushing to the Hub.

**Solutions:**
1. **Verify authentication:**
   ```python
   mcp__huggingface__hf_whoami()  # Check who's authenticated
   ```

2. **Check token permissions:**
   - Go to https://huggingface.co/settings/tokens
   - Ensure the token has "write" permission
   - The token must not be "read-only"

3. **Verify the token in the job:**
   ```python
   "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # Must be in job config
   ```

4. **Check repo permissions:**
   - The user must have write access to the target repo
   - If it's an org repo, the user must be a member with write access
   - The repo must exist, or the user must have permission to create it

## Job Stuck or Not Starting

**Problem:** The job shows "pending" or "starting" for an extended period.

**Solutions:**
- Check the Jobs dashboard for status: https://huggingface.co/jobs
- Verify hardware availability (some GPU types may have queues)
- Try a different hardware flavor if one is heavily utilized
- Check for account billing issues (Jobs requires a paid plan)

**Typical startup times:**
- CPU jobs: 10-30 seconds
- GPU jobs: 30-90 seconds
- If >3 minutes: likely queued or stuck

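The ">3 minutes means queued or stuck" rule of thumb can be applied programmatically. A generic polling sketch, with the status-fetching callable injected so you can wire it to `hf_jobs("inspect", ...)` or any other status source:

```python
import time

def wait_until_running(get_status, timeout_s=180, poll_s=5,
                       clock=time.monotonic, sleep=time.sleep):
    """Poll get_status() until it returns "running" or timeout_s elapses.

    Returns the final status; anything other than "running" after ~3 minutes
    should be treated as likely queued or stuck.
    """
    deadline = clock() + timeout_s
    status = get_status()
    while status != "running" and clock() < deadline:
        sleep(poll_s)
        status = get_status()
    return status
```
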
## Training Loss Not Decreasing

**Problem:** Training runs, but the loss stays flat or doesn't improve.

**Solutions:**
1. **Check the learning rate:** It may be too low (try 2e-5 to 5e-5) or too high (try 1e-6)
2. **Verify dataset quality:** Inspect examples to ensure they're reasonable
3. **Check model size:** Very small models may not have the capacity for the task
4. **Increase training steps:** You may need more epochs or a larger dataset
5. **Verify the dataset format:** A wrong format may cause degraded training

## Logs Not Appearing

**Problem:** Cannot see training logs or progress.

**Solutions:**
1. **Wait 30-60 seconds:** Initial logs can be delayed
2. **Check logs via the MCP tool:**
   ```python
   hf_jobs("logs", {"job_id": "your-job-id"})
   ```
3. **Use Trackio for real-time monitoring:** See `references/trackio_guide.md`
4. **Verify the job is actually running:**
   ```python
   hf_jobs("inspect", {"job_id": "your-job-id"})
   ```

## Checkpoint/Resume Issues

**Problem:** Cannot resume from a checkpoint, or the checkpoint was not saved.

**Solutions:**
1. **Enable checkpoint saving:**
   ```python
   SFTConfig(
       save_strategy="steps",
       save_steps=100,
       hub_strategy="every_save",  # Push each checkpoint
   )
   ```

2. **Verify checkpoints pushed to Hub:** Check the model repo for checkpoint folders

3. **Resume from a checkpoint:**
   ```python
   trainer = SFTTrainer(
       model="username/model-name",  # Can be a checkpoint path
       train_dataset=dataset,
       args=config,
   )
   # resume_from_checkpoint is passed to train(), not the trainer constructor
   trainer.train(resume_from_checkpoint="username/model-name/checkpoint-1000")
   ```

## Getting Help

If issues persist:

1. **Check TRL documentation:**
   ```python
   hf_doc_search("your issue", product="trl")
   ```

2. **Check Jobs documentation:**
   ```python
   hf_doc_fetch("https://huggingface.co/docs/huggingface_hub/guides/jobs")
   ```

3. **Review related guides:**
   - `references/hub_saving.md` - Hub authentication issues
   - `references/hardware_guide.md` - Hardware selection and specs
   - `references/training_patterns.md` - Eval dataset requirements
   - SKILL.md "Working with Scripts" section - Script format and URL issues

4. **Ask in HF forums:** https://discuss.huggingface.co/
@@ -0,0 +1,313 @@
# Unsloth: Fast Fine-Tuning with Memory Optimization

**Unsloth** is a fine-tuning library that provides ~2x faster training and ~60% less VRAM usage for LLM training. It's particularly useful when working with limited GPU memory or when speed is critical.

- **GitHub**: [unslothai/unsloth](https://github.com/unslothai/unsloth)
- **Docs**: [unsloth.ai/docs](https://unsloth.ai/docs)

## When to Use Unsloth

Use Unsloth if instructed to do so, or if one of the following use cases applies:

| Use Case | Recommendation |
|----------|----------------|
| Standard text LLM fine-tuning | TRL is sufficient, but Unsloth is faster |
| Limited GPU memory | **Use Unsloth** - 60% less VRAM |
| Need maximum speed | **Use Unsloth** - 2x faster |
| Large models (>13B) | **Use Unsloth** - memory efficiency critical |

## Supported Models

Unsloth supports many popular models, including:
- **Text LLMs**: Llama 3/3.1/3.2/3.3, Qwen 2.5/3, Mistral, Phi-4, Gemma 2/3, LFM2/2.5
- **Vision LLMs**: Qwen3-VL, Gemma 3, Llama 3.2 Vision, Pixtral

Use Unsloth's pre-optimized model variants when available:
```python
# Unsloth-optimized models load faster and use less memory
model_id = "unsloth/LFM2.5-1.2B-Instruct"  # 4-bit quantized
model_id = "unsloth/gemma-3-4b-pt"         # Vision model
model_id = "unsloth/Qwen3-VL-8B-Instruct"  # Vision model
```

## Installation

```python
# /// script
# dependencies = [
#     "unsloth",
#     "trl",
#     "datasets",
#     "trackio",
# ]
# ///
```

## Basic Usage: Text LLM

```python
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

# Load model with Unsloth optimizations
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="LiquidAI/LFM2.5-1.2B-Instruct",
    max_seq_length=4096,
)

# Add LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "in_proj", "w1", "w2", "w3"],
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# Load dataset
dataset = load_dataset("trl-lib/Capybara", split="train")

# Train with TRL
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="./output",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_steps=500,
        learning_rate=2e-4,
        report_to="trackio",
    ),
)

trainer.train()
```

## LFM2.5 Specific Settings

For LFM2.5 inference, use these recommended generation parameters:

**Instruct models:**
```python
temperature = 0.1
top_k = 50
top_p = 0.1
repetition_penalty = 1.05
```

**Thinking models:**
```python
temperature = 0.05
top_k = 50
repetition_penalty = 1.05
```

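These parameters map directly onto `transformers` `generate()` keyword arguments; one way to keep both presets handy as splat-able dicts (note that `do_sample=True` is added here as an assumption, since sampling parameters only take effect when sampling is enabled):

```python
# Recommended LFM2.5 generation settings from above, as generate() kwargs
LFM25_GEN = {
    "instruct": {
        "do_sample": True,
        "temperature": 0.1,
        "top_k": 50,
        "top_p": 0.1,
        "repetition_penalty": 1.05,
    },
    "thinking": {
        "do_sample": True,
        "temperature": 0.05,
        "top_k": 50,
        "repetition_penalty": 1.05,
    },
}

# e.g. model.generate(**inputs, **LFM25_GEN["instruct"], max_new_tokens=256)
```
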
## Vision-Language Models (VLMs)

Unsloth provides specialized support for VLMs with `FastVisionModel`:

```python
from unsloth import FastVisionModel, get_chat_template
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

# Load VLM with Unsloth
model, processor = FastVisionModel.from_pretrained(
    "unsloth/gemma-3-4b-pt",  # or "unsloth/Qwen3-VL-8B-Instruct"
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
)

# Add LoRA for all modalities
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=True,      # Train vision encoder
    finetune_language_layers=True,    # Train language model
    finetune_attention_modules=True,  # Train attention
    finetune_mlp_modules=True,        # Train MLPs
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)

# Apply chat template (required for base models)
processor = get_chat_template(processor, "gemma-3")

# Load VLM dataset (with images and messages)
dataset = load_dataset("your-vlm-dataset", split="train", streaming=True)

# Enable training mode
FastVisionModel.for_training(model)

# Train with VLM-specific collator
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=processor.tokenizer,
    data_collator=UnslothVisionDataCollator(model, processor),
    args=SFTConfig(
        output_dir="./vlm-output",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_steps=500,
        learning_rate=2e-4,
        # VLM-specific settings
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        report_to="trackio",
    ),
)

trainer.train()
```

## Key Differences from Standard TRL

| Aspect | Standard TRL | Unsloth |
|--------|--------------|---------|
| Model loading | `AutoModelForCausalLM.from_pretrained()` | `FastLanguageModel.from_pretrained()` |
| LoRA setup | `PeftModel` / `LoraConfig` | `FastLanguageModel.get_peft_model()` |
| VLM loading | Limited support | `FastVisionModel.from_pretrained()` |
| VLM collator | Manual | `UnslothVisionDataCollator` |
| Memory usage | Standard | ~60% less |
| Training speed | Standard | ~2x faster |

## VLM Dataset Format

VLM datasets should have:
- `images`: A list of PIL images or image paths
- `messages`: Conversation format with image references

```python
{
    "images": [<PIL.Image>, ...],
    "messages": [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image"}
        ]},
        {"role": "assistant", "content": "This image shows..."}
    ]
}
```

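A record in this shape can be sanity-checked before training; a hypothetical helper that counts `{"type": "image"}` placeholders against the attached images:

```python
def validate_vlm_record(record: dict) -> bool:
    """Check a VLM record: images present, and image placeholders match them."""
    if "images" not in record or "messages" not in record:
        return False
    placeholders = sum(
        1
        for msg in record["messages"]
        if isinstance(msg.get("content"), list)   # string content has no images
        for part in msg["content"]
        if part.get("type") == "image"
    )
    return placeholders == len(record["images"])

ok = {
    "images": ["<image>"],
    "messages": [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image"},
        ]},
        {"role": "assistant", "content": "This image shows..."},
    ],
}
```
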
## Streaming Datasets

For large VLM datasets, use streaming to avoid disk space issues:

```python
dataset = load_dataset(
    "your-vlm-dataset",
    split="train",
    streaming=True,  # Stream from Hub
)

# Must use max_steps with streaming (no epoch-based training)
SFTConfig(max_steps=500, ...)
```

## Saving Models

### Save LoRA Adapter

```python
model.save_pretrained("./adapter")
processor.save_pretrained("./adapter")

# Push to Hub
model.push_to_hub("username/my-vlm-adapter")
processor.push_to_hub("username/my-vlm-adapter")
```

### Merge and Save Full Model

```python
# Merge LoRA weights into base model
model = model.merge_and_unload()

# Save merged model
model.save_pretrained("./merged")
tokenizer.save_pretrained("./merged")
```

### Convert to GGUF

Unsloth models can be converted to GGUF for llama.cpp/Ollama:

```python
# Save in 16-bit for GGUF conversion
model.save_pretrained_gguf("./gguf", tokenizer, quantization_method="f16")

# Or directly quantize
model.save_pretrained_gguf("./gguf", tokenizer, quantization_method="q4_k_m")
```

## Qwen3-VL Specific Settings

For Qwen3-VL models, use these recommended settings:

**Instruct models:**
```python
temperature = 0.7
top_p = 0.8
presence_penalty = 1.5
```

**Thinking models:**
```python
temperature = 1.0
top_p = 0.95
presence_penalty = 0.0
```

## Hardware Requirements
|
||||
|
||||
| Model | Min VRAM (Unsloth 4-bit) | Recommended GPU |
|
||||
|-------|--------------------------|-----------------|
|
||||
| 2B-4B | 8GB | T4, L4 |
|
||||
| 7B-8B | 16GB | A10G, L4x4 |
|
||||
| 13B | 24GB | A10G-large |
|
||||
| 30B+ | 48GB+ | A100 |
|
||||
|
||||
## Example: Full VLM Training Script

See `scripts/unsloth_sft_example.py` for a complete production-ready example that includes:

- Unsloth VLM setup
- Streaming dataset support
- Trackio monitoring
- Hub push
- CLI arguments

Run locally:

```bash
uv run scripts/unsloth_sft_example.py \
  --dataset trl-lib/Capybara \
  --max-steps 500 \
  --output-repo username/my-model
```

Run on HF Jobs:

```python
hf_jobs("uv", {
    "script": "<script content>",
    "flavor": "a10g-large",
    "timeout": "2h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

## See Also

- `scripts/unsloth_sft_example.py` - Complete text LLM training example
- [Unsloth Documentation](https://unsloth.ai/docs)
- [LFM2.5 Guide](https://unsloth.ai/docs/models/tutorials/lfm2.5)
- [Qwen3-VL Guide](https://unsloth.ai/docs/models/qwen3-vl-how-to-run-and-fine-tune)
- [Unsloth GitHub](https://github.com/unslothai/unsloth)

@@ -0,0 +1,424 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "transformers>=4.36.0",
#     "peft>=0.7.0",
#     "torch>=2.0.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "sentencepiece>=0.1.99",
#     "protobuf>=3.20.0",
#     "numpy",
#     "gguf",
# ]
# ///

"""
GGUF Conversion Script - Production Ready

This script converts a LoRA fine-tuned model to GGUF format for use with:
- llama.cpp
- Ollama
- LM Studio
- Other GGUF-compatible tools

PREREQUISITES (install these FIRST):
- Ubuntu/Debian: sudo apt-get update && sudo apt-get install -y build-essential cmake
- RHEL/CentOS: sudo yum groupinstall -y "Development Tools" && sudo yum install -y cmake
- macOS: xcode-select --install && brew install cmake

Usage:
    Set environment variables:
    - ADAPTER_MODEL: Your fine-tuned model (e.g., "username/my-finetuned-model")
    - BASE_MODEL: Base model used for fine-tuning (e.g., "Qwen/Qwen2.5-0.5B")
    - OUTPUT_REPO: Where to upload GGUF files (e.g., "username/my-model-gguf")
    - HF_USERNAME: Your Hugging Face username (optional, for README)

Dependencies: All required packages are declared in the PEP 723 header above.
"""

import os
import subprocess
import sys

import torch
from huggingface_hub import HfApi
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def check_system_dependencies():
    """Check if required system packages are available."""
    print("🔍 Checking system dependencies...")

    # Check for git
    if subprocess.run(["which", "git"], capture_output=True).returncode != 0:
        print(" ❌ git is not installed. Please install it:")
        print(" Ubuntu/Debian: sudo apt-get install git")
        print(" RHEL/CentOS: sudo yum install git")
        print(" macOS: brew install git")
        return False

    # Check for make or cmake
    has_make = subprocess.run(["which", "make"], capture_output=True).returncode == 0
    has_cmake = subprocess.run(["which", "cmake"], capture_output=True).returncode == 0

    if not has_make and not has_cmake:
        print(" ❌ Neither make nor cmake found. Please install build tools:")
        print(" Ubuntu/Debian: sudo apt-get install build-essential cmake")
        print(" RHEL/CentOS: sudo yum groupinstall 'Development Tools' && sudo yum install cmake")
        print(" macOS: xcode-select --install && brew install cmake")
        return False

    print(" ✅ System dependencies found")
    return True


def run_command(cmd, description):
    """Run a command with error handling."""
    print(f" {description}...")
    try:
        result = subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            text=True,
        )
        if result.stdout:
            print(f" {result.stdout[:200]}")  # Show first 200 chars
        return True
    except subprocess.CalledProcessError as e:
        print(f" ❌ Command failed: {' '.join(cmd)}")
        if e.stdout:
            print(f" STDOUT: {e.stdout[:500]}")
        if e.stderr:
            print(f" STDERR: {e.stderr[:500]}")
        return False
    except FileNotFoundError:
        print(f" ❌ Command not found: {cmd[0]}")
        return False

print("🔄 GGUF Conversion Script")
print("=" * 60)

# Check system dependencies first
if not check_system_dependencies():
    print("\n❌ Please install the missing system dependencies and try again.")
    sys.exit(1)

# Configuration from environment variables
ADAPTER_MODEL = os.environ.get("ADAPTER_MODEL", "evalstate/qwen-capybara-medium")
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B")
OUTPUT_REPO = os.environ.get("OUTPUT_REPO", "evalstate/qwen-capybara-medium-gguf")
username = os.environ.get("HF_USERNAME", ADAPTER_MODEL.split('/')[0])

print("\n📦 Configuration:")
print(f" Base model: {BASE_MODEL}")
print(f" Adapter model: {ADAPTER_MODEL}")
print(f" Output repo: {OUTPUT_REPO}")

# Step 1: Load base model and adapter
print("\n🔧 Step 1: Loading base model and LoRA adapter...")
print(" (This may take a few minutes)")

try:
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,  # keyword accepted by the pinned transformers>=4.36
        device_map="auto",
        trust_remote_code=True,
    )
    print(" ✅ Base model loaded")
except Exception as e:
    print(f" ❌ Failed to load base model: {e}")
    sys.exit(1)

try:
    # Load and merge adapter
    print(" Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
    print(" ✅ Adapter loaded")

    print(" Merging adapter with base model...")
    merged_model = model.merge_and_unload()
    print(" ✅ Models merged!")
except Exception as e:
    print(f" ❌ Failed to merge models: {e}")
    sys.exit(1)

try:
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL, trust_remote_code=True)
    print(" ✅ Tokenizer loaded")
except Exception as e:
    print(f" ❌ Failed to load tokenizer: {e}")
    sys.exit(1)

# Step 2: Save merged model temporarily
print("\n💾 Step 2: Saving merged model...")
merged_dir = "/tmp/merged_model"
try:
    merged_model.save_pretrained(merged_dir, safe_serialization=True)
    tokenizer.save_pretrained(merged_dir)
    print(f" ✅ Merged model saved to {merged_dir}")
except Exception as e:
    print(f" ❌ Failed to save merged model: {e}")
    sys.exit(1)

# Step 3: Install llama.cpp for conversion
print("\n📥 Step 3: Setting up llama.cpp for GGUF conversion...")

# Clone llama.cpp repository
if not run_command(
    ["git", "clone", "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"],
    "Cloning llama.cpp repository",
):
    print(" Trying alternative clone method...")
    # Try shallow clone
    if not run_command(
        ["git", "clone", "--depth", "1", "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"],
        "Cloning llama.cpp (shallow)",
    ):
        sys.exit(1)

# Install Python dependencies
print(" Installing Python dependencies...")
if not run_command(
    ["pip", "install", "-r", "/tmp/llama.cpp/requirements.txt"],
    "Installing llama.cpp requirements",
):
    print(" ⚠️ Some requirements may already be installed")

if not run_command(
    ["pip", "install", "sentencepiece", "protobuf"],
    "Installing tokenizer dependencies",
):
    print(" ⚠️ Tokenizer dependencies may already be installed")

# Step 4: Convert to GGUF (FP16)
print("\n🔄 Step 4: Converting to GGUF format (FP16)...")
gguf_output_dir = "/tmp/gguf_output"
os.makedirs(gguf_output_dir, exist_ok=True)

convert_script = "/tmp/llama.cpp/convert_hf_to_gguf.py"
model_name = ADAPTER_MODEL.split('/')[-1]
gguf_file = f"{gguf_output_dir}/{model_name}-f16.gguf"

print(" Running conversion...")
if not run_command(
    [
        sys.executable, convert_script,
        merged_dir,
        "--outfile", gguf_file,
        "--outtype", "f16",
    ],
    "Converting to FP16",
):
    print(" ❌ Conversion failed!")
    sys.exit(1)

print(f" ✅ FP16 GGUF created: {gguf_file}")

# Step 5: Quantize to different formats
print("\n⚙️ Step 5: Creating quantized versions...")

# Build quantize tool using CMake (more reliable than make)
print(" Building quantize tool with CMake...")
os.makedirs("/tmp/llama.cpp/build", exist_ok=True)

# Configure with CMake
if not run_command(
    ["cmake", "-B", "/tmp/llama.cpp/build", "-S", "/tmp/llama.cpp",
     "-DGGML_CUDA=OFF"],
    "Configuring with CMake",
):
    print(" ❌ CMake configuration failed")
    sys.exit(1)

# Build just the quantize tool
if not run_command(
    ["cmake", "--build", "/tmp/llama.cpp/build", "--target", "llama-quantize", "-j", "4"],
    "Building llama-quantize",
):
    print(" ❌ Build failed!")
    sys.exit(1)

print(" ✅ Quantize tool built")

# Use the CMake build output path
quantize_bin = "/tmp/llama.cpp/build/bin/llama-quantize"

# Common quantization formats
quant_formats = [
    ("Q4_K_M", "4-bit, medium quality (recommended)"),
    ("Q5_K_M", "5-bit, higher quality"),
    ("Q8_0", "8-bit, very high quality"),
]

quantized_files = []
for quant_type, description in quant_formats:
    print(f" Creating {quant_type} quantization ({description})...")
    quant_file = f"{gguf_output_dir}/{model_name}-{quant_type.lower()}.gguf"

    if not run_command(
        [quantize_bin, gguf_file, quant_file, quant_type],
        f"Quantizing to {quant_type}",
    ):
        print(f" ⚠️ Skipping {quant_type} due to error")
        continue

    quantized_files.append((quant_file, quant_type))

    # Get file size
    size_mb = os.path.getsize(quant_file) / (1024 * 1024)
    print(f" ✅ {quant_type}: {size_mb:.1f} MB")

if not quantized_files:
    print(" ❌ No quantized versions were created successfully")
    sys.exit(1)

# Step 6: Upload to Hub
print("\n☁️ Step 6: Uploading to Hugging Face Hub...")
api = HfApi()

# Create repo
print(f" Creating repository: {OUTPUT_REPO}")
try:
    api.create_repo(repo_id=OUTPUT_REPO, repo_type="model", exist_ok=True)
    print(" ✅ Repository ready")
except Exception as e:
    print(f" ℹ️ Repository may already exist: {e}")

# Upload FP16 version
print(" Uploading FP16 GGUF...")
try:
    api.upload_file(
        path_or_fileobj=gguf_file,
        path_in_repo=f"{model_name}-f16.gguf",
        repo_id=OUTPUT_REPO,
    )
    print(" ✅ FP16 uploaded")
except Exception as e:
    print(f" ❌ Upload failed: {e}")
    sys.exit(1)

# Upload quantized versions
for quant_file, quant_type in quantized_files:
    print(f" Uploading {quant_type}...")
    try:
        api.upload_file(
            path_or_fileobj=quant_file,
            path_in_repo=f"{model_name}-{quant_type.lower()}.gguf",
            repo_id=OUTPUT_REPO,
        )
        print(f" ✅ {quant_type} uploaded")
    except Exception as e:
        print(f" ❌ Upload failed for {quant_type}: {e}")
        continue

# Create README
print("\n📝 Creating README...")
readme_content = f"""---
base_model: {BASE_MODEL}
tags:
- gguf
- llama.cpp
- quantized
- trl
- sft
---

# {OUTPUT_REPO.split('/')[-1]}

This is a GGUF conversion of [{ADAPTER_MODEL}](https://huggingface.co/{ADAPTER_MODEL}), which is a LoRA fine-tuned version of [{BASE_MODEL}](https://huggingface.co/{BASE_MODEL}).

## Model Details

- **Base Model:** {BASE_MODEL}
- **Fine-tuned Model:** {ADAPTER_MODEL}
- **Training:** Supervised Fine-Tuning (SFT) with TRL
- **Format:** GGUF (for llama.cpp, Ollama, LM Studio, etc.)

## Available Quantizations

| File | Quant | Size | Description | Use Case |
|------|-------|------|-------------|----------|
| {model_name}-f16.gguf | F16 | ~1GB | Full precision | Best quality, slower |
| {model_name}-q8_0.gguf | Q8_0 | ~500MB | 8-bit | High quality |
| {model_name}-q5_k_m.gguf | Q5_K_M | ~350MB | 5-bit medium | Good quality, smaller |
| {model_name}-q4_k_m.gguf | Q4_K_M | ~300MB | 4-bit medium | Recommended - good balance |

## Usage

### With llama.cpp

```bash
# Download model
hf download {OUTPUT_REPO} {model_name}-q4_k_m.gguf

# Run with llama.cpp
./llama-cli -m {model_name}-q4_k_m.gguf -p "Your prompt here"
```

### With Ollama

1. Create a `Modelfile`:
```
FROM ./{model_name}-q4_k_m.gguf
```

2. Create the model:
```bash
ollama create my-model -f Modelfile
ollama run my-model
```

### With LM Studio

1. Download the `.gguf` file
2. Import it into LM Studio
3. Start chatting!

## License

Inherits the license from the base model: {BASE_MODEL}

## Citation

```bibtex
@misc{{{OUTPUT_REPO.split('/')[-1].replace('-', '_')},
  author = {{{username}}},
  title = {{{OUTPUT_REPO.split('/')[-1]}}},
  year = {{2025}},
  publisher = {{Hugging Face}},
  url = {{https://huggingface.co/{OUTPUT_REPO}}}
}}
```

---

*Converted to GGUF format using llama.cpp*
"""

try:
    api.upload_file(
        path_or_fileobj=readme_content.encode(),
        path_in_repo="README.md",
        repo_id=OUTPUT_REPO,
    )
    print(" ✅ README uploaded")
except Exception as e:
    print(f" ❌ README upload failed: {e}")

print("\n" + "=" * 60)
print("✅ GGUF Conversion Complete!")
print(f"📦 Repository: https://huggingface.co/{OUTPUT_REPO}")
print("\n📥 Download with:")
print(f" hf download {OUTPUT_REPO} {model_name}-q4_k_m.gguf")
print("\n🚀 Use with Ollama:")
print(" 1. Download the GGUF file")
print(f" 2. Create Modelfile: FROM ./{model_name}-q4_k_m.gguf")
print(" 3. ollama create my-model -f Modelfile")
print(" 4. ollama run my-model")
print("=" * 60)

@@ -0,0 +1,417 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = []
# ///
"""
Dataset Format Inspector for TRL Training (LLM-Optimized Output)

Inspects Hugging Face datasets to determine TRL training compatibility.
Uses the Datasets Server API for instant results - no dataset download needed!

ULTRA-EFFICIENT: Uses the HF Datasets Server API - completes in <2 seconds.

Usage with HF Jobs:
    hf_jobs("uv", {
        "script": "https://huggingface.co/datasets/evalstate/trl-helpers/raw/main/dataset_inspector.py",
        "script_args": ["--dataset", "your/dataset", "--split", "train"]
    })
"""

import argparse
import json
import sys
import urllib.error
import urllib.parse
import urllib.request
from typing import Any, Dict, List, Optional


def parse_args():
    parser = argparse.ArgumentParser(description="Inspect dataset format for TRL training")
    parser.add_argument("--dataset", type=str, required=True, help="Dataset name")
    parser.add_argument("--split", type=str, default="train", help="Dataset split (default: train)")
    parser.add_argument("--config", type=str, default="default", help="Dataset config name (default: default)")
    parser.add_argument("--preview", type=int, default=150, help="Max chars per field preview")
    parser.add_argument("--samples", type=int, default=5, help="Number of samples to fetch (default: 5)")
    parser.add_argument("--json-output", action="store_true", help="Output as JSON")
    return parser.parse_args()


def api_request(url: str) -> Optional[Dict]:
    """Make an API request to the Datasets Server; return None on 404."""
    try:
        with urllib.request.urlopen(url, timeout=10) as response:
            return json.loads(response.read().decode())
    except urllib.error.HTTPError as e:
        if e.code == 404:
            return None
        raise Exception(f"API request failed: {e.code} {e.reason}")
    except Exception as e:
        raise Exception(f"API request failed: {str(e)}")


def get_splits(dataset: str) -> Optional[Dict]:
    """Get the available splits for a dataset."""
    url = f"https://datasets-server.huggingface.co/splits?dataset={urllib.parse.quote(dataset)}"
    return api_request(url)


def get_rows(dataset: str, config: str, split: str, offset: int = 0, length: int = 5) -> Optional[Dict]:
    """Get rows from a dataset."""
    url = f"https://datasets-server.huggingface.co/rows?dataset={urllib.parse.quote(dataset)}&config={config}&split={split}&offset={offset}&length={length}"
    return api_request(url)


def find_columns(columns: List[str], patterns: List[str]) -> List[str]:
    """Find columns matching any of the given patterns."""
    return [c for c in columns if any(p in c.lower() for p in patterns)]


def check_sft_compatibility(columns: List[str]) -> Dict[str, Any]:
    """Check SFT compatibility."""
    has_messages = "messages" in columns
    has_text = "text" in columns
    has_prompt_completion = "prompt" in columns and "completion" in columns

    ready = has_messages or has_text or has_prompt_completion

    possible_prompt = find_columns(columns, ["prompt", "instruction", "question", "input"])
    possible_response = find_columns(columns, ["response", "completion", "output", "answer"])

    return {
        "ready": ready,
        "reason": "messages" if has_messages else "text" if has_text else "prompt+completion" if has_prompt_completion else None,
        "possible_prompt": possible_prompt[0] if possible_prompt else None,
        "possible_response": possible_response[0] if possible_response else None,
        "has_context": "context" in columns,
    }


def check_dpo_compatibility(columns: List[str]) -> Dict[str, Any]:
    """Check DPO compatibility."""
    has_standard = "prompt" in columns and "chosen" in columns and "rejected" in columns

    possible_prompt = find_columns(columns, ["prompt", "instruction", "question", "input"])
    possible_chosen = find_columns(columns, ["chosen", "preferred", "winner"])
    possible_rejected = find_columns(columns, ["rejected", "dispreferred", "loser"])

    can_map = bool(possible_prompt and possible_chosen and possible_rejected)

    return {
        "ready": has_standard,
        "can_map": can_map,
        "prompt_col": possible_prompt[0] if possible_prompt else None,
        "chosen_col": possible_chosen[0] if possible_chosen else None,
        "rejected_col": possible_rejected[0] if possible_rejected else None,
    }


def check_grpo_compatibility(columns: List[str]) -> Dict[str, Any]:
    """Check GRPO compatibility."""
    has_prompt = "prompt" in columns
    has_no_responses = "chosen" not in columns and "rejected" not in columns

    possible_prompt = find_columns(columns, ["prompt", "instruction", "question", "input"])

    return {
        "ready": has_prompt and has_no_responses,
        "can_map": bool(possible_prompt) and has_no_responses,
        "prompt_col": possible_prompt[0] if possible_prompt else None,
    }


def check_kto_compatibility(columns: List[str]) -> Dict[str, Any]:
    """Check KTO compatibility."""
    return {"ready": "prompt" in columns and "completion" in columns and "label" in columns}

def generate_mapping_code(method: str, info: Dict[str, Any]) -> str:
    """Generate mapping code for a training method (None if not needed/possible)."""
    if method == "SFT":
        if info["ready"]:
            return None

        prompt_col = info.get("possible_prompt")
        response_col = info.get("possible_response")
        has_context = info.get("has_context", False)

        if not prompt_col:
            return None

        if has_context and response_col:
            return f"""def format_for_sft(example):
    text = f"Instruction: {{example['{prompt_col}']}}\n\n"
    if example.get('context'):
        text += f"Context: {{example['context']}}\n\n"
    text += f"Response: {{example['{response_col}']}}"
    return {{'text': text}}

dataset = dataset.map(format_for_sft, remove_columns=dataset.column_names)"""
        elif response_col:
            return f"""def format_for_sft(example):
    return {{'text': f"{{example['{prompt_col}']}}\n\n{{example['{response_col}']}}"}}

dataset = dataset.map(format_for_sft, remove_columns=dataset.column_names)"""
        else:
            return f"""def format_for_sft(example):
    return {{'text': example['{prompt_col}']}}

dataset = dataset.map(format_for_sft, remove_columns=dataset.column_names)"""

    elif method == "DPO":
        if info["ready"] or not info["can_map"]:
            return None

        return f"""def format_for_dpo(example):
    return {{
        'prompt': example['{info['prompt_col']}'],
        'chosen': example['{info['chosen_col']}'],
        'rejected': example['{info['rejected_col']}'],
    }}

dataset = dataset.map(format_for_dpo, remove_columns=dataset.column_names)"""

    elif method == "GRPO":
        if info["ready"] or not info["can_map"]:
            return None

        return f"""def format_for_grpo(example):
    return {{'prompt': example['{info['prompt_col']}']}}

dataset = dataset.map(format_for_grpo, remove_columns=dataset.column_names)"""

    return None

def format_value_preview(value: Any, max_chars: int) -> str:
    """Format a value for preview."""
    if value is None:
        return "None"
    elif isinstance(value, str):
        return value[:max_chars] + ("..." if len(value) > max_chars else "")
    elif isinstance(value, list):
        if len(value) > 0 and isinstance(value[0], dict):
            return f"[{len(value)} items] Keys: {list(value[0].keys())}"
        preview = str(value)
        return preview[:max_chars] + ("..." if len(preview) > max_chars else "")
    else:
        preview = str(value)
        return preview[:max_chars] + ("..." if len(preview) > max_chars else "")


def main():
    args = parse_args()

    print("Fetching dataset info via the Datasets Server API...")

    try:
        # Get splits info
        splits_data = get_splits(args.dataset)
        if not splits_data or "splits" not in splits_data:
            print(f"ERROR: Could not fetch splits for dataset '{args.dataset}'")
            print(" Dataset may not exist or is not accessible via the Datasets Server API")
            sys.exit(1)

        # Find the right config
        available_configs = set()
        split_found = False
        config_to_use = args.config

        for split_info in splits_data["splits"]:
            available_configs.add(split_info["config"])
            if split_info["config"] == args.config and split_info["split"] == args.split:
                split_found = True

        # If the default config is not found, try the first available one
        if not split_found and available_configs:
            config_to_use = list(available_configs)[0]
            print(f"Config '{args.config}' not found, trying '{config_to_use}'...")

        # Get rows
        rows_data = get_rows(args.dataset, config_to_use, args.split, offset=0, length=args.samples)

        if not rows_data or "rows" not in rows_data:
            print(f"ERROR: Could not fetch rows for dataset '{args.dataset}'")
            print(f" Split '{args.split}' may not exist")
            print(f" Available configs: {', '.join(sorted(available_configs))}")
            sys.exit(1)

        rows = rows_data["rows"]
        if not rows:
            print(f"ERROR: No rows found in split '{args.split}'")
            sys.exit(1)

        # Extract column info from the first row
        first_row = rows[0]["row"]
        columns = list(first_row.keys())
        features = rows_data.get("features", [])

        # Get total count if available
        total_examples = "Unknown"
        for split_info in splits_data["splits"]:
            if split_info["config"] == config_to_use and split_info["split"] == args.split:
                total_examples = f"{split_info.get('num_examples', 'Unknown'):,}" if isinstance(split_info.get('num_examples'), int) else "Unknown"
                break

    except Exception as e:
        print(f"ERROR: {str(e)}")
        sys.exit(1)

    # Run compatibility checks
    sft_info = check_sft_compatibility(columns)
    dpo_info = check_dpo_compatibility(columns)
    grpo_info = check_grpo_compatibility(columns)
    kto_info = check_kto_compatibility(columns)

    # Determine recommended methods
    recommended = []
    if sft_info["ready"]:
        recommended.append("SFT")
    elif sft_info["possible_prompt"]:
        recommended.append("SFT (needs mapping)")

    if dpo_info["ready"]:
        recommended.append("DPO")
    elif dpo_info["can_map"]:
        recommended.append("DPO (needs mapping)")

    if grpo_info["ready"]:
        recommended.append("GRPO")
    elif grpo_info["can_map"]:
        recommended.append("GRPO (needs mapping)")

    if kto_info["ready"]:
        recommended.append("KTO")

    # JSON output mode
    if args.json_output:
        result = {
            "dataset": args.dataset,
            "config": config_to_use,
            "split": args.split,
            "total_examples": total_examples,
            "columns": columns,
            "features": [{"name": f["name"], "type": f["type"]} for f in features] if features else [],
            "compatibility": {
                "SFT": sft_info,
                "DPO": dpo_info,
                "GRPO": grpo_info,
                "KTO": kto_info,
            },
            "recommended_methods": recommended,
        }
        print(json.dumps(result, indent=2))
        sys.exit(0)

    # Human-readable output optimized for LLM parsing
    print("=" * 80)
    print("DATASET INSPECTION RESULTS")
    print("=" * 80)

    print(f"\nDataset: {args.dataset}")
    print(f"Config: {config_to_use}")
    print(f"Split: {args.split}")
    print(f"Total examples: {total_examples}")
    print(f"Samples fetched: {len(rows)}")

    print(f"\n{'COLUMNS':-<80}")
    if features:
        for feature in features:
            print(f" {feature['name']}: {feature['type']}")
    else:
        for col in columns:
            print(f" {col}: (type info not available)")

    print(f"\n{'EXAMPLE DATA':-<80}")
    example = first_row
    for col in columns:
        value = example.get(col)
        display = format_value_preview(value, args.preview)
        print(f"\n{col}:")
        print(f" {display}")

    print(f"\n{'TRAINING METHOD COMPATIBILITY':-<80}")

    # SFT
    print(f"\n[SFT] {'✓ READY' if sft_info['ready'] else '✗ NEEDS MAPPING'}")
    if sft_info["ready"]:
        print(f" Reason: Dataset has '{sft_info['reason']}' field")
        print(" Action: Use directly with SFTTrainer")
    elif sft_info["possible_prompt"]:
        print(f" Detected: prompt='{sft_info['possible_prompt']}' response='{sft_info['possible_response']}'")
        print(" Action: Apply mapping code (see below)")
    else:
        print(" Status: Cannot determine mapping - manual inspection needed")

    # DPO
    print(f"\n[DPO] {'✓ READY' if dpo_info['ready'] else '✗ NEEDS MAPPING' if dpo_info['can_map'] else '✗ INCOMPATIBLE'}")
    if dpo_info["ready"]:
        print(" Reason: Dataset has 'prompt', 'chosen', 'rejected' fields")
        print(" Action: Use directly with DPOTrainer")
    elif dpo_info["can_map"]:
        print(f" Detected: prompt='{dpo_info['prompt_col']}' chosen='{dpo_info['chosen_col']}' rejected='{dpo_info['rejected_col']}'")
        print(" Action: Apply mapping code (see below)")
    else:
        print(" Status: Missing required fields (prompt + chosen + rejected)")

    # GRPO
    print(f"\n[GRPO] {'✓ READY' if grpo_info['ready'] else '✗ NEEDS MAPPING' if grpo_info['can_map'] else '✗ INCOMPATIBLE'}")
    if grpo_info["ready"]:
        print(" Reason: Dataset has 'prompt' field")
        print(" Action: Use directly with GRPOTrainer")
    elif grpo_info["can_map"]:
        print(f" Detected: prompt='{grpo_info['prompt_col']}'")
        print(" Action: Apply mapping code (see below)")
    else:
        print(" Status: Missing prompt field")

    # KTO
    print(f"\n[KTO] {'✓ READY' if kto_info['ready'] else '✗ INCOMPATIBLE'}")
    if kto_info["ready"]:
        print(" Reason: Dataset has 'prompt', 'completion', 'label' fields")
        print(" Action: Use directly with KTOTrainer")
    else:
        print(" Status: Missing required fields (prompt + completion + label)")

    # Mapping code
    print(f"\n{'MAPPING CODE (if needed)':-<80}")

    mapping_needed = False

    sft_mapping = generate_mapping_code("SFT", sft_info)
    if sft_mapping:
        print("\n# For SFT Training:")
        print(sft_mapping)
        mapping_needed = True

    dpo_mapping = generate_mapping_code("DPO", dpo_info)
    if dpo_mapping:
        print("\n# For DPO Training:")
        print(dpo_mapping)
        mapping_needed = True

    grpo_mapping = generate_mapping_code("GRPO", grpo_info)
    if grpo_mapping:
        print("\n# For GRPO Training:")
        print(grpo_mapping)
        mapping_needed = True

    if not mapping_needed:
        print("\nNo mapping needed - dataset is ready for training!")

    print(f"\n{'SUMMARY':-<80}")
    print(f"Recommended training methods: {', '.join(recommended) if recommended else 'None (dataset needs formatting)'}")
    print("\nNote: Used the Datasets Server API (instant, no download required)")

    print("\n" + "=" * 80)
    sys.exit(0)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.exit(0)
    except Exception as e:
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)

@@ -0,0 +1,150 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = []
# ///
"""
Estimate training time and cost for TRL jobs.

Usage with uv:
    uv run estimate_cost.py --model <model> --dataset <dataset> --hardware <flavor>

Example:
    uv run estimate_cost.py --model Qwen/Qwen2.5-0.5B --dataset trl-lib/Capybara --hardware a10g-large
"""

import argparse

# Hardware costs per hour (approximate)
HARDWARE_COSTS = {
    "t4-small": 0.75,
    "t4-medium": 1.50,
    "l4x1": 2.50,
    "a10g-small": 3.50,
    "a10g-large": 5.00,
    "a10g-largex2": 10.00,
    "a10g-largex4": 20.00,
    "a100-large": 10.00,
}

# Model sizes in billions of parameters
MODEL_SIZES = {
    "0.5B": 0.5,
    "1.5B": 1.5,
    "3B": 3,
    "7B": 7,
    "13B": 13,
}


def estimate_training_time(model_params, dataset_size, epochs, hardware):
    """Estimate training time in hours."""
    # Rough estimates based on empirical observations.
    # These are approximations; actual times will vary.

    base_time_per_1k_examples = 0.1  # hours for a 1B model on a10g-large

    # Scale with model size, dataset size, and epoch count
    time = base_time_per_1k_examples * model_params * (dataset_size / 1000) * epochs

    # Adjust for hardware (relative to the a10g-large baseline)
    hardware_multipliers = {
        "t4-small": 2.0,
        "t4-medium": 1.5,
        "l4x1": 1.2,
        "a10g-small": 1.3,
        "a10g-large": 1.0,
        "a10g-largex2": 0.6,
        "a10g-largex4": 0.4,
        "a100-large": 0.7,
    }

    multiplier = hardware_multipliers.get(hardware, 1.0)
    time *= multiplier

    return time


def parse_args():
    parser = argparse.ArgumentParser(description="Estimate training cost for TRL jobs")
    parser.add_argument("--model", required=True, help="Model name or size (e.g., 'Qwen/Qwen2.5-0.5B' or '0.5B')")
    parser.add_argument("--dataset", required=True, help="Dataset name")
    parser.add_argument("--hardware", required=True, choices=HARDWARE_COSTS.keys(), help="Hardware flavor")
    parser.add_argument("--dataset-size", type=int, help="Override dataset size (number of examples)")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    return parser.parse_args()


def extract_model_size(model_name):
    """Extract the model size from its name, or return the parsed value."""
    for size_str, size_val in MODEL_SIZES.items():
|
||||
if size_str in model_name:
|
||||
return size_val
|
||||
|
||||
# Try to parse directly
|
||||
try:
|
||||
if "B" in model_name:
|
||||
return float(model_name.replace("B", ""))
|
||||
except:
|
||||
pass
|
||||
|
||||
return 1.0 # Default to 1B if can't determine
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Extract model parameters
|
||||
model_params = extract_model_size(args.model)
|
||||
print(f"📊 Model: {args.model} (~{model_params}B parameters)")
|
||||
|
||||
# Estimate dataset size (would need to load to get real size)
|
||||
if args.dataset_size:
|
||||
dataset_size = args.dataset_size
|
||||
else:
|
||||
# Common dataset sizes (approximations)
|
||||
dataset_sizes = {
|
||||
"trl-lib/Capybara": 16000,
|
||||
"Anthropic/hh-rlhf": 160000,
|
||||
}
|
||||
dataset_size = dataset_sizes.get(args.dataset, 10000)
|
||||
|
||||
print(f"📦 Dataset: {args.dataset} (~{dataset_size} examples)")
|
||||
print(f"🔄 Epochs: {args.epochs}")
|
||||
print(f"💻 Hardware: {args.hardware}")
|
||||
print()
|
||||
|
||||
# Estimate training time
|
||||
estimated_hours = estimate_training_time(model_params, dataset_size, args.epochs, args.hardware)
|
||||
estimated_cost = estimated_hours * HARDWARE_COSTS[args.hardware]
|
||||
|
||||
# Recommend timeout with buffer
|
||||
recommended_timeout_hours = estimated_hours * 1.3 # 30% buffer
|
||||
|
||||
print(f"⏱️ Estimated training time: {estimated_hours:.1f} hours")
|
||||
print(f"💰 Estimated cost: ${estimated_cost:.2f}")
|
||||
print(f"⏰ Recommended timeout: {recommended_timeout_hours:.1f}h (with 30% buffer)")
|
||||
print()
|
||||
|
||||
# Warnings and recommendations
|
||||
if estimated_hours > 4:
|
||||
print("⚠️ Long training time - consider:")
|
||||
print(" - Using faster hardware")
|
||||
print(" - Reducing epochs")
|
||||
print(" - Using a smaller dataset subset for testing")
|
||||
|
||||
if model_params >= 7 and args.hardware not in ["a10g-largex2", "a10g-largex4", "a100-large"]:
|
||||
print("⚠️ Large model - consider using:")
|
||||
print(" - Larger GPU (a100-large)")
|
||||
print(" - Multi-GPU setup (a10g-largex2 or a10g-largex4)")
|
||||
print(" - LoRA/PEFT for memory efficiency")
|
||||
|
||||
print()
|
||||
print("📋 Example job configuration:")
|
||||
print(f"""
|
||||
hf_jobs("uv", {{
|
||||
"script": "your_training_script.py",
|
||||
"flavor": "{args.hardware}",
|
||||
"timeout": "{recommended_timeout_hours:.0f}h",
|
||||
"secrets": {{"HF_TOKEN": "$HF_TOKEN"}}
|
||||
}})
|
||||
""")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
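To make the heuristic concrete, here is the same arithmetic worked by hand for the docstring's own example (Qwen2.5-0.5B on trl-lib/Capybara, a10g-large); constants are copied from the script's tables, so this is a sketch of the estimate, not a measured result:

```python
# Worked example of the script's cost heuristic.
base_time_per_1k = 0.1        # hours per 1k examples for a 1B model on a10g-large
model_params = 0.5            # billions of parameters (Qwen2.5-0.5B)
dataset_size = 16_000         # trl-lib/Capybara, approximate
epochs = 3                    # script default
hardware_multiplier = 1.0     # a10g-large is the baseline
cost_per_hour = 5.00          # USD/hour for a10g-large, approximate

hours = base_time_per_1k * model_params * (dataset_size / 1000) * epochs * hardware_multiplier
cost = hours * cost_per_hour
timeout = hours * 1.3         # 30% safety buffer, as in the script

print(f"{hours:.1f}h, ${cost:.2f}, timeout {timeout:.1f}h")  # → 2.4h, $12.00, timeout 3.1h
```

So a small run like this lands well under the 4-hour warning threshold the script checks for.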
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
# ]
# ///

"""
Production-ready DPO training example for preference learning.

DPO (Direct Preference Optimization) trains models on preference pairs
(chosen vs rejected responses) without requiring a reward model.

Usage with hf_jobs MCP tool:
    hf_jobs("uv", {
        "script": '''<paste this entire file>''',
        "flavor": "a10g-large",
        "timeout": "3h",
        "secrets": {"HF_TOKEN": "$HF_TOKEN"},
    })

Or submit the script content directly inline without saving to a file.
"""

import trackio
from datasets import load_dataset
from trl import DPOTrainer, DPOConfig


# Load preference dataset
print("📦 Loading dataset...")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
print(f"✅ Dataset loaded: {len(dataset)} preference pairs")

# Create train/eval split
print("🔀 Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f"   Train: {len(train_dataset)} pairs")
print(f"   Eval: {len(eval_dataset)} pairs")

# Training configuration
config = DPOConfig(
    # CRITICAL: Hub settings
    output_dir="qwen-dpo-aligned",
    push_to_hub=True,
    hub_model_id="username/qwen-dpo-aligned",
    hub_strategy="every_save",

    # DPO-specific parameters
    beta=0.1,  # KL penalty coefficient (higher = stay closer to reference)

    # Training parameters
    num_train_epochs=1,  # DPO typically needs fewer epochs than SFT
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-7,  # DPO uses a much lower LR than SFT
    # max_length=1024,  # Default - only set if you need a different sequence length

    # Logging & checkpointing
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,

    # Evaluation - IMPORTANT: only enable if an eval_dataset is provided
    eval_strategy="steps",
    eval_steps=100,

    # Optimization
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",

    # Monitoring
    report_to="trackio",  # Integrate with Trackio
    project="meaningful_project_name",  # Trackio project name
    run_name="baseline-run",  # Descriptive name for this training run
)

# Initialize and train
# Note: DPO requires an instruct-tuned model as the base
print("🎯 Initializing trainer...")
trainer = DPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # Use the instruct model, not the base model
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # CRITICAL: must provide eval_dataset when eval_strategy is enabled
    args=config,
)

print("🚀 Starting DPO training...")
trainer.train()

print("💾 Pushing to Hub...")
trainer.push_to_hub()

# Finish Trackio tracking
trackio.finish()

print("✅ Complete! Model at: https://huggingface.co/username/qwen-dpo-aligned")
print("📊 View metrics at: https://huggingface.co/spaces/username/trackio")
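Since the script above hinges on the dataset containing chosen/rejected pairs, here is a minimal sketch of the record shape a preference dataset provides. The contents are invented for illustration (not taken from ultrafeedback_binarized); the structural point is that both sides are full conversations sharing the same prompt:

```python
# Hypothetical preference-pair record, illustrating the chosen/rejected shape
# that DPO-style training consumes.
example = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5, probably."},
    ],
}

# DPO optimizes the margin between two continuations of the SAME prompt,
# so the user turns should be identical on both sides.
assert example["chosen"][0] == example["rejected"][0]
print("valid preference pair")
```

If your data instead stores a separate `prompt` column plus bare response strings, convert it to this paired-conversation form before training.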
@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
# ]
# ///

"""
Production-ready GRPO training example for online RL.

GRPO (Group Relative Policy Optimization) is an online RL method that
optimizes relative to group performance. Best for tasks with automatic
reward signals like code execution or math verification.

Usage with hf_jobs MCP tool:
    hf_jobs("uv", {
        "script": '''<paste this entire file>''',
        "flavor": "a10g-large",
        "timeout": "4h",
        "secrets": {"HF_TOKEN": "$HF_TOKEN"},
    })

Or submit the script content directly inline without saving to a file.

Note: For most GRPO use cases, the TRL maintained script is recommended:
https://raw.githubusercontent.com/huggingface/trl/main/examples/scripts/grpo.py
"""

import trackio
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig


# Load dataset (GRPO uses prompt-only format)
dataset = load_dataset("trl-lib/math_shepherd", split="train")
print(f"✅ Dataset loaded: {len(dataset)} prompts")

# Training configuration
config = GRPOConfig(
    # CRITICAL: Hub settings
    output_dir="qwen-grpo-math",
    push_to_hub=True,
    hub_model_id="username/qwen-grpo-math",
    hub_strategy="every_save",

    # Training parameters
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-6,

    # Logging & checkpointing
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,

    # Optimization
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",

    # Monitoring
    report_to="trackio",  # Integrate with Trackio
    project="meaningful_project_name",  # Trackio project name
    run_name="baseline-run",  # Descriptive name for this training run
)

# Initialize and train
# Note: GRPO requires an instruct-tuned model as the base
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
    args=config,
)

print("🚀 Starting GRPO training...")
trainer.train()

print("💾 Pushing to Hub...")
trainer.push_to_hub()

# Finish Trackio tracking
trackio.finish()

print("✅ Complete! Model at: https://huggingface.co/username/qwen-grpo-math")
print("📊 View metrics at: https://huggingface.co/spaces/username/trackio")
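The docstring above notes that GRPO is best for tasks with automatic reward signals. A sketch of what such a signal can look like for math verification follows; the function name and the way answers are extracted are illustrative assumptions, not TRL's API (in TRL, reward functions are supplied to the trainer and score batches of generated completions, returning one float per completion):

```python
# Hypothetical reward function sketch for a math-verification task:
# 1.0 if the completion's final token matches the reference answer, else 0.0.
def exact_answer_reward(completions, answers):
    rewards = []
    for completion, answer in zip(completions, answers):
        # Treat the last whitespace-separated token as the model's final answer
        predicted = completion.strip().split()[-1] if completion.strip() else ""
        rewards.append(1.0 if predicted == answer else 0.0)
    return rewards

print(exact_answer_reward(["The answer is 42", "It is 7"], ["42", "8"]))  # → [1.0, 0.0]
```

Because the reward is computed programmatically per generated sample, no reward model needs to be trained, which is exactly the setting where GRPO shines.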
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
# ]
# ///

"""
Production-ready SFT training example with all best practices.

This script demonstrates:
- Trackio integration for real-time monitoring
- LoRA/PEFT for efficient training
- Proper Hub saving configuration
- Train/eval split for monitoring
- Checkpoint management
- Optimized training parameters

Usage with hf_jobs MCP tool:
    hf_jobs("uv", {
        "script": '''<paste this entire file>''',
        "flavor": "a10g-large",
        "timeout": "3h",
        "secrets": {"HF_TOKEN": "$HF_TOKEN"},
    })

Or submit the script content directly inline without saving to a file.
"""

import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig


# Load dataset
print("📦 Loading dataset...")
dataset = load_dataset("trl-lib/Capybara", split="train")
print(f"✅ Dataset loaded: {len(dataset)} examples")

# Create train/eval split
print("🔀 Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f"   Train: {len(train_dataset)} examples")
print(f"   Eval: {len(eval_dataset)} examples")

# Note: For memory-constrained demos, skip eval by using the full dataset as
# train_dataset and removing eval_dataset, eval_strategy, and eval_steps below

# Training configuration
config = SFTConfig(
    # CRITICAL: Hub settings
    output_dir="qwen-capybara-sft",
    push_to_hub=True,
    hub_model_id="username/qwen-capybara-sft",
    hub_strategy="every_save",  # Push checkpoints

    # Training parameters
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    # max_length=1024,  # Default - only set if you need a different sequence length

    # Logging & checkpointing
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,

    # Evaluation - IMPORTANT: only enable if an eval_dataset is provided
    eval_strategy="steps",
    eval_steps=100,

    # Optimization
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",

    # Monitoring
    report_to="trackio",  # Integrate with Trackio
    project="meaningful_project_name",  # Trackio project name
    run_name="baseline-run",  # Descriptive name for this training run
)

# LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
)

# Initialize and train
print("🎯 Initializing trainer...")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # CRITICAL: must provide eval_dataset when eval_strategy is enabled
    args=config,
    peft_config=peft_config,
)

print("🚀 Starting training...")
trainer.train()

print("💾 Pushing to Hub...")
trainer.push_to_hub()

# Finish Trackio tracking
trackio.finish()

print("✅ Complete! Model at: https://huggingface.co/username/qwen-capybara-sft")
print("📊 View metrics at: https://huggingface.co/spaces/username/trackio")
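To see why the LoRA config above is "efficient", it helps to count what it actually trains: each targeted projection of shape (d_out, d_in) gains two small matrices, A (r × d_in) and B (d_out × r), and only those are updated. The dimensions below are assumptions for illustration (hidden size 896, 24 layers, q_proj 896→896 and v_proj 896→128 under grouped-query attention); check the real model config before relying on the exact number:

```python
# Rough trainable-parameter count for LoRA with r=16 on q_proj and v_proj.
# All dimensions here are assumed example values, not read from the model.
r = 16
hidden = 896      # assumed hidden size
n_layers = 24     # assumed layer count

def lora_params(d_in, d_out, rank):
    # LoRA adds A (rank x d_in) and B (d_out x rank) per targeted projection
    return rank * d_in + d_out * rank

per_layer = lora_params(hidden, 896, r) + lora_params(hidden, 128, r)
total = per_layer * n_layers
print(f"~{total:,} trainable LoRA parameters")  # → ~1,081,344 trainable LoRA parameters
```

Roughly one million trainable parameters against a half-billion-parameter base is what makes LoRA runs fit on modest GPUs.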
@@ -0,0 +1,512 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "unsloth",
#     "datasets",
#     "trl==0.22.2",
#     "huggingface_hub[hf_transfer]",
#     "trackio",
#     "tensorboard",
#     "transformers==4.57.3",
# ]
# ///
"""
Fine-tune LLMs using Unsloth optimizations for ~60% less VRAM and 2x faster training.

Supports epoch-based or step-based training with optional eval split.
Default model: LFM2.5-1.2B-Instruct (Liquid Foundation Model).

Epoch-based training (recommended for full datasets):
    uv run unsloth_sft_example.py \
        --dataset mlabonne/FineTome-100k \
        --num-epochs 1 \
        --eval-split 0.2 \
        --output-repo your-username/model-finetuned

Run on HF Jobs (1 epoch with eval):
    hf jobs uv run unsloth_sft_example.py \
        --flavor a10g-small --secrets HF_TOKEN --timeout 4h \
        -- --dataset mlabonne/FineTome-100k \
        --num-epochs 1 \
        --eval-split 0.2 \
        --output-repo your-username/model-finetuned

Step-based training (for quick tests):
    uv run unsloth_sft_example.py \
        --dataset mlabonne/FineTome-100k \
        --max-steps 500 \
        --output-repo your-username/model-finetuned
"""

import argparse
import logging
import os
import sys
import time

# Force unbuffered output for HF Jobs logs
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


def check_cuda():
    """Check CUDA availability and exit if not available."""
    import torch

    if not torch.cuda.is_available():
        logger.error("CUDA is not available. This script requires a GPU.")
        logger.error("Run on a machine with a CUDA-capable GPU or use HF Jobs:")
        logger.error(
            "  hf jobs uv run unsloth_sft_example.py --flavor a10g-small ..."
        )
        sys.exit(1)
    logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Fine-tune LLMs with Unsloth optimizations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Quick test run
  uv run unsloth_sft_example.py \\
      --dataset mlabonne/FineTome-100k \\
      --max-steps 50 \\
      --output-repo username/model-test

  # Full training with eval
  uv run unsloth_sft_example.py \\
      --dataset mlabonne/FineTome-100k \\
      --num-epochs 1 \\
      --eval-split 0.2 \\
      --output-repo username/model-finetuned

  # With Trackio monitoring
  uv run unsloth_sft_example.py \\
      --dataset mlabonne/FineTome-100k \\
      --num-epochs 1 \\
      --output-repo username/model-finetuned \\
      --trackio-space username/trackio
""",
    )

    # Model and data
    parser.add_argument(
        "--base-model",
        default="LiquidAI/LFM2.5-1.2B-Instruct",
        help="Base model (default: LiquidAI/LFM2.5-1.2B-Instruct)",
    )
    parser.add_argument(
        "--dataset",
        required=True,
        help="Dataset in ShareGPT/conversation format (e.g., mlabonne/FineTome-100k)",
    )
    parser.add_argument(
        "--output-repo",
        required=True,
        help="HF Hub repo to push model to (e.g., 'username/model-finetuned')",
    )

    # Training config
    parser.add_argument(
        "--num-epochs",
        type=float,
        default=None,
        help="Number of epochs (default: None). Use instead of --max-steps.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=None,
        help="Training steps (default: None). Use for quick tests or streaming.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Per-device batch size (default: 2)",
    )
    parser.add_argument(
        "--gradient-accumulation",
        type=int,
        default=4,
        help="Gradient accumulation steps (default: 4). Effective batch = batch-size * this",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=2e-4,
        help="Learning rate (default: 2e-4)",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=2048,
        help="Maximum sequence length (default: 2048)",
    )

    # LoRA config
    parser.add_argument(
        "--lora-r",
        type=int,
        default=16,
        help="LoRA rank (default: 16). Higher = more capacity but more VRAM",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=16,
        help="LoRA alpha (default: 16). Same as r per Unsloth recommendation",
    )

    # Logging
    parser.add_argument(
        "--trackio-space",
        default=None,
        help="HF Space for Trackio dashboard (e.g., 'username/trackio')",
    )
    parser.add_argument(
        "--run-name",
        default=None,
        help="Custom run name for Trackio (default: auto-generated)",
    )
    parser.add_argument(
        "--save-local",
        default="unsloth-output",
        help="Local directory to save model (default: unsloth-output)",
    )

    # Evaluation and data control
    parser.add_argument(
        "--eval-split",
        type=float,
        default=0.0,
        help="Fraction of data for evaluation (0.0-0.5). Default: 0.0 (no eval)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=None,
        help="Limit samples (default: None = use all)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=3407,
        help="Random seed for reproducibility (default: 3407)",
    )
    parser.add_argument(
        "--merge-model",
        action="store_true",
        default=False,
        help="Merge LoRA weights into base model before uploading (larger file, easier to use)",
    )

    return parser.parse_args()


def main():
    args = parse_args()

    # Validate epochs/steps configuration
    if not args.num_epochs and not args.max_steps:
        args.num_epochs = 1
        logger.info("Using default --num-epochs=1")

    # Determine training duration display
    if args.num_epochs:
        duration_str = f"{args.num_epochs} epoch(s)"
    else:
        duration_str = f"{args.max_steps} steps"

    print("=" * 70)
    print("LLM Fine-tuning with Unsloth")
    print("=" * 70)
    print("\nConfiguration:")
    print(f"  Base model: {args.base_model}")
    print(f"  Dataset: {args.dataset}")
    print(f"  Num samples: {args.num_samples or 'all'}")
    print(f"  Eval split: {args.eval_split if args.eval_split > 0 else '(disabled)'}")
    print(f"  Seed: {args.seed}")
    print(f"  Training: {duration_str}")
    print(f"  Batch size: {args.batch_size} x {args.gradient_accumulation} = {args.batch_size * args.gradient_accumulation}")
    print(f"  Learning rate: {args.learning_rate}")
    print(f"  LoRA rank: {args.lora_r}")
    print(f"  Max seq length: {args.max_seq_length}")
    print(f"  Output repo: {args.output_repo}")
    print(f"  Trackio space: {args.trackio_space or '(not configured)'}")
    print()

    # Check CUDA before heavy imports
    check_cuda()

    # Enable fast transfers
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

    # Set Trackio space if provided
    if args.trackio_space:
        os.environ["TRACKIO_SPACE_ID"] = args.trackio_space
        logger.info(f"Trackio dashboard: https://huggingface.co/spaces/{args.trackio_space}")

    # Import heavy dependencies
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import standardize_data_formats, train_on_responses_only
    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig
    from huggingface_hub import login

    # Login to Hub
    token = os.environ.get("HF_TOKEN") or os.environ.get("hfjob")
    if token:
        login(token=token)
        logger.info("Logged in to Hugging Face Hub")
    else:
        logger.warning("HF_TOKEN not set - model upload may fail")

    # 1. Load model
    print("\n[1/5] Loading model...")
    start = time.time()

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.base_model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=False,
        load_in_8bit=False,
        load_in_16bit=True,
        full_finetuning=False,
    )

    # Add LoRA adapters
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "in_proj", "w1", "w2", "w3"],
        lora_alpha=args.lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=args.seed,
        use_rslora=False,
        loftq_config=None,
    )
    print(f"Model loaded in {time.time() - start:.1f}s")

    # 2. Load and prepare dataset
    print("\n[2/5] Loading dataset...")
    start = time.time()

    dataset = load_dataset(args.dataset, split="train")
    print(f"  Dataset has {len(dataset)} total samples")

    if args.num_samples:
        dataset = dataset.select(range(min(args.num_samples, len(dataset))))
        print(f"  Limited to {len(dataset)} samples")

    # Auto-detect and normalize conversation column
    for col in ["messages", "conversations", "conversation"]:
        if col in dataset.column_names and isinstance(dataset[0][col], list):
            if col != "conversations":
                dataset = dataset.rename_column(col, "conversations")
            break
    dataset = standardize_data_formats(dataset)

    # Apply chat template
    def formatting_prompts_func(examples):
        texts = tokenizer.apply_chat_template(
            examples["conversations"],
            tokenize=False,
            add_generation_prompt=False,
        )
        # Remove BOS token to avoid duplicates
        return {"text": [x.removeprefix(tokenizer.bos_token) for x in texts]}

    dataset = dataset.map(formatting_prompts_func, batched=True)

    # Split for evaluation if requested
    if args.eval_split > 0:
        split = dataset.train_test_split(test_size=args.eval_split, seed=args.seed)
        train_data = split["train"]
        eval_data = split["test"]
        print(f"  Train: {len(train_data)} samples, Eval: {len(eval_data)} samples")
    else:
        train_data = dataset
        eval_data = None

    print(f"  Dataset ready in {time.time() - start:.1f}s")

    # 3. Configure trainer
    print("\n[3/5] Configuring trainer...")

    # Calculate steps per epoch for logging/eval intervals
    effective_batch = args.batch_size * args.gradient_accumulation
    num_samples = len(train_data)
    steps_per_epoch = num_samples // effective_batch

    # Determine run name and logging steps
    if args.run_name:
        run_name = args.run_name
    elif args.num_epochs:
        run_name = f"unsloth-sft-{args.num_epochs}ep"
    else:
        run_name = f"unsloth-sft-{args.max_steps}steps"

    if args.num_epochs:
        logging_steps = max(1, steps_per_epoch // 10)
        save_steps = max(1, steps_per_epoch // 4)
    else:
        logging_steps = max(1, args.max_steps // 20)
        save_steps = max(1, args.max_steps // 4)

    # Determine reporting backend
    if args.trackio_space:
        report_to = ["tensorboard", "trackio"]
    else:
        report_to = ["tensorboard"]

    training_config = SFTConfig(
        output_dir=args.save_local,
        dataset_text_field="text",
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        warmup_steps=5,
        num_train_epochs=args.num_epochs if args.num_epochs else 1,
        max_steps=args.max_steps if args.max_steps else -1,
        learning_rate=args.learning_rate,
        logging_steps=logging_steps,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=args.seed,
        max_length=args.max_seq_length,
        report_to=report_to,
        run_name=run_name,
        push_to_hub=True,
        hub_model_id=args.output_repo,
        save_steps=save_steps,
        save_total_limit=3,
    )

    # Add evaluation config if eval is enabled
    if eval_data:
        if args.num_epochs:
            training_config.eval_strategy = "epoch"
            print("  Evaluation enabled: every epoch")
        else:
            training_config.eval_strategy = "steps"
            training_config.eval_steps = max(1, args.max_steps // 5)
            print(f"  Evaluation enabled: every {training_config.eval_steps} steps")

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_data,
        eval_dataset=eval_data,
        args=training_config,
    )

    # Train on responses only (mask user inputs)
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|im_start|>user\n",
        response_part="<|im_start|>assistant\n",
    )

    # 4. Train
    print(f"\n[4/5] Training for {duration_str}...")
    if args.num_epochs:
        print(f"  (~{steps_per_epoch} steps/epoch, {int(steps_per_epoch * args.num_epochs)} total steps)")
    start = time.time()

    train_result = trainer.train()

    train_time = time.time() - start
    total_steps = train_result.metrics.get("train_steps", args.max_steps or steps_per_epoch * args.num_epochs)
    print(f"\nTraining completed in {train_time / 60:.1f} minutes")
    print(f"  Speed: {total_steps / train_time:.2f} steps/s")

    # Print training metrics
    train_loss = train_result.metrics.get("train_loss")
    if train_loss:
        print(f"  Final train loss: {train_loss:.4f}")

    # Print eval results if eval was enabled
    if eval_data:
        print("\nRunning final evaluation...")
        try:
            eval_results = trainer.evaluate()
            eval_loss = eval_results.get("eval_loss")
            if eval_loss:
                print(f"  Final eval loss: {eval_loss:.4f}")
                if train_loss:
                    ratio = eval_loss / train_loss
                    if ratio > 1.5:
                        print(f"  Warning: Eval loss is {ratio:.1f}x train loss - possible overfitting")
                    else:
                        print(f"  Eval/train ratio: {ratio:.2f} - model generalizes well")
        except Exception as e:
            print(f"  Warning: Final evaluation failed: {e}")
            print("  Continuing to save model...")

    # 5. Save and push
    print("\n[5/5] Saving model...")

    if args.merge_model:
        print("Merging LoRA weights into base model...")
        print(f"\nPushing merged model to {args.output_repo}...")
        model.push_to_hub_merged(
            args.output_repo,
            tokenizer=tokenizer,
            save_method="merged_16bit",
        )
        print(f"Merged model available at: https://huggingface.co/{args.output_repo}")
    else:
        model.save_pretrained(args.save_local)
        tokenizer.save_pretrained(args.save_local)
        print(f"Saved locally to {args.save_local}/")

        print(f"\nPushing adapter to {args.output_repo}...")
        model.push_to_hub(args.output_repo, tokenizer=tokenizer)
        print(f"Adapter available at: https://huggingface.co/{args.output_repo}")

    print("\n" + "=" * 70)
    print("Done!")
    print("=" * 70)


if __name__ == "__main__":
    if len(sys.argv) == 1:
        print("=" * 70)
        print("LLM Fine-tuning with Unsloth")
        print("=" * 70)
        print("\nFine-tune language models with optional train/eval split.")
        print("\nFeatures:")
        print("  - ~60% less VRAM with Unsloth optimizations")
        print("  - 2x faster training vs standard methods")
        print("  - Epoch-based or step-based training")
        print("  - Optional evaluation to detect overfitting")
        print("  - Trains only on assistant responses (masked user inputs)")
        print("\nEpoch-based training:")
        print("\n  uv run unsloth_sft_example.py \\")
        print("      --dataset mlabonne/FineTome-100k \\")
        print("      --num-epochs 1 \\")
        print("      --eval-split 0.2 \\")
        print("      --output-repo your-username/model-finetuned")
        print("\nHF Jobs example:")
        print("\n  hf jobs uv run unsloth_sft_example.py \\")
        print("      --flavor a10g-small --secrets HF_TOKEN --timeout 4h \\")
        print("      -- --dataset mlabonne/FineTome-100k \\")
        print("      --num-epochs 1 \\")
        print("      --eval-split 0.2 \\")
        print("      --output-repo your-username/model-finetuned")
        print("\nFor full help: uv run unsloth_sft_example.py --help")
        print("=" * 70)
        sys.exit(0)

    main()
|
||||
@@ -1,11 +1,15 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-paper-publisher"
name: hugging-face-paper-publisher
description: Publish and manage research papers on Hugging Face Hub. Supports creating paper pages, linking papers to models/datasets, claiming authorship, and generating professional markdown-based research articles.
risk: unknown
source: community
---

# Overview

## When to Use

Use this skill when a user wants to publish, link, index, or manage research papers on the Hugging Face Hub.
This skill provides comprehensive tools for AI engineers and researchers to publish, manage, and link research papers on the Hugging Face Hub. It streamlines the workflow from paper creation to publication, including integration with arXiv, model/dataset linking, and authorship management.

## Integration with HF Ecosystem
@@ -19,6 +23,9 @@ This skill provides comprehensive tools for AI engineers and researchers to publ
1.0.0

# Dependencies
The included script uses PEP 723 inline dependencies. Prefer `uv run` over manual environment setup.

- huggingface_hub>=0.26.0
- pyyaml>=6.0.3
- requests>=2.32.5
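
Concretely, the script declares these dependencies in a PEP 723 inline metadata block (reproduced from the header of `scripts/paper_manager.py` in this commit), which is what lets `uv run` resolve them without a pre-built environment:

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "huggingface_hub",
#     "pyyaml",
#     "requests",
#     "python-dotenv",
# ]
# ///
```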
@@ -57,15 +64,15 @@ This skill provides comprehensive tools for AI engineers and researchers to publ
The skill includes Python scripts in `scripts/` for paper publishing operations.

### Prerequisites
- Run scripts with `uv run` (dependencies are resolved from the PEP 723 script header)
- Set the `HF_TOKEN` environment variable to a token with write access
- Alternatively, install the dependencies yourself (`uv add huggingface_hub pyyaml requests markdown python-dotenv`) and activate the virtual environment (`source .venv/bin/activate`)

> **All paths are relative to the directory containing this SKILL.md file.**
> Before running any script, first `cd` to that directory or use the full path.


### Method 1: Index Paper from arXiv

Add a paper to Hugging Face Paper Pages from arXiv.
@@ -435,7 +442,7 @@ uv run scripts/paper_manager.py link \
**Workflow 3: Update Model with Paper Reference**
```bash
# 1. Get current README
huggingface-cli download username/model-name README.md
hf download username/model-name README.md

# 2. Add paper link
uv run scripts/paper_manager.py link \

@@ -0,0 +1,326 @@
# Example Usage: HF Paper Publisher Skill

This document demonstrates common workflows for publishing research papers on Hugging Face Hub.

## Example 1: Index an Existing arXiv Paper

If you've already published a paper on arXiv and want to make it discoverable on Hugging Face:

```bash
# Check if paper exists
uv run scripts/paper_manager.py check --arxiv-id "2301.12345"

# Index the paper
uv run scripts/paper_manager.py index --arxiv-id "2301.12345"

# Get paper information
uv run scripts/paper_manager.py info --arxiv-id "2301.12345"
```

Expected output:
```json
{
  "exists": true,
  "url": "https://huggingface.co/papers/2301.12345",
  "arxiv_id": "2301.12345",
  "arxiv_url": "https://arxiv.org/abs/2301.12345"
}
```

## Example 2: Link Paper to Your Model

After indexing a paper, link it to your model repository:

```bash
# Link single paper
uv run scripts/paper_manager.py link \
  --repo-id "username/my-awesome-model" \
  --repo-type "model" \
  --arxiv-id "2301.12345"

# Link multiple papers
uv run scripts/paper_manager.py link \
  --repo-id "username/my-awesome-model" \
  --repo-type "model" \
  --arxiv-ids "2301.12345,2302.67890"
```

This will:
1. Download the model's README.md
2. Add or update YAML frontmatter
3. Insert paper references with links
4. Upload the updated README
5. Hub automatically creates `arxiv:2301.12345` tags
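
For illustration, the section inserted into the README looks roughly like this (the exact wording comes from the script's `_add_paper_to_readme` helper, using the example ID above):

```markdown
<!-- paper-manager:start -->
## Paper

This model is based on research presented in:

**[View on arXiv](https://arxiv.org/abs/2301.12345)** | **[View on Hugging Face](https://huggingface.co/papers/2301.12345)**
<!-- paper-manager:end -->
```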

## Example 3: Link Paper to Dataset

Same process for datasets:

```bash
uv run scripts/paper_manager.py link \
  --repo-id "username/my-dataset" \
  --repo-type "dataset" \
  --arxiv-id "2301.12345" \
  --citation "$(cat citation.bib)"
```

## Example 4: Create a New Research Article

Generate a research paper from template:

```bash
# Create with standard template
uv run scripts/paper_manager.py create \
  --template "standard" \
  --title "Efficient Fine-Tuning of Large Language Models" \
  --authors "Jane Doe, John Smith" \
  --abstract "We propose a novel approach to fine-tuning..." \
  --output "paper.md"

# Create with modern template
uv run scripts/paper_manager.py create \
  --template "modern" \
  --title "Vision Transformers for Medical Imaging" \
  --output "medical_vit_paper.md"

# Create ML experiment report
uv run scripts/paper_manager.py create \
  --template "ml-report" \
  --title "BERT Fine-tuning Experiment Results" \
  --output "bert_experiment_report.md"
```

## Example 5: Generate Citations

Get formatted citations for papers:

```bash
# BibTeX format
uv run scripts/paper_manager.py citation \
  --arxiv-id "2301.12345" \
  --format "bibtex"
```

Output:
```bibtex
@article{arxiv2301_12345,
  title={Efficient Fine-Tuning of Large Language Models},
  author={Doe, Jane and Smith, John},
  journal={arXiv preprint arXiv:2301.12345},
  year={2023}
}
```

## Example 6: Complete Workflow - New Paper

Full workflow from paper creation to publication:

```bash
# Step 1: Create research article
uv run scripts/paper_manager.py create \
  --template "modern" \
  --title "Novel Architecture for Multimodal Learning" \
  --authors "Alice Chen, Bob Kumar" \
  --output "multimodal_paper.md"

# Step 2: Edit the paper (use your favorite editor)
# vim multimodal_paper.md

# Step 3: Submit to arXiv (external process)
# Upload to arxiv.org, receive arXiv ID: 2312.99999

# Step 4: Index on Hugging Face
uv run scripts/paper_manager.py index --arxiv-id "2312.99999"

# Step 5: Link to your models/datasets
uv run scripts/paper_manager.py link \
  --repo-id "alice/multimodal-model-v1" \
  --repo-type "model" \
  --arxiv-id "2312.99999"

uv run scripts/paper_manager.py link \
  --repo-id "alice/multimodal-dataset" \
  --repo-type "dataset" \
  --arxiv-id "2312.99999"

# Step 6: Generate citation for README
uv run scripts/paper_manager.py citation \
  --arxiv-id "2312.99999" \
  --format "bibtex" > citation.bib
```

## Example 7: Batch Link Papers

Link multiple papers to multiple repositories:

```bash
#!/bin/bash

# List of papers
PAPERS=("2301.12345" "2302.67890" "2303.11111")

# List of models
MODELS=("username/model-a" "username/model-b" "username/model-c")

# Link each paper to each model
for paper in "${PAPERS[@]}"; do
  for model in "${MODELS[@]}"; do
    echo "Linking $paper to $model..."
    uv run scripts/paper_manager.py link \
      --repo-id "$model" \
      --repo-type "model" \
      --arxiv-id "$paper"
  done
done
```

## Example 8: Update Model Card with Paper Info

Get paper info and manually update model card:

```bash
# Get paper information
uv run scripts/paper_manager.py info \
  --arxiv-id "2301.12345" \
  --format "text" > paper_info.txt

# View the information
cat paper_info.txt

# Manually incorporate into your model card or use the link command
```

## Example 9: Search and Discover Papers

```bash
# Search for papers (opens browser)
uv run scripts/paper_manager.py search \
  --query "transformer attention mechanism"
```

## Example 10: Working with tfrere's Template

This skill complements [tfrere's research article template](https://huggingface.co/spaces/tfrere/research-article-template):

```bash
# 1. Use tfrere's Space to create a beautiful web-based paper
# Visit: https://huggingface.co/spaces/tfrere/research-article-template

# 2. Export your paper content to markdown

# 3. Submit to arXiv

# 4. Use this skill to index and link
uv run scripts/paper_manager.py index --arxiv-id "YOUR_ARXIV_ID"
uv run scripts/paper_manager.py link \
  --repo-id "your-username/your-model" \
  --arxiv-id "YOUR_ARXIV_ID"
```

## Example 11: Error Handling

```bash
# Check if paper exists before linking
if uv run scripts/paper_manager.py check --arxiv-id "2301.12345" | grep -q '"exists": true'; then
  echo "Paper exists, proceeding with link..."
  uv run scripts/paper_manager.py link \
    --repo-id "username/model" \
    --arxiv-id "2301.12345"
else
  echo "Paper doesn't exist, indexing first..."
  uv run scripts/paper_manager.py index --arxiv-id "2301.12345"
  uv run scripts/paper_manager.py link \
    --repo-id "username/model" \
    --arxiv-id "2301.12345"
fi
```

## Example 12: CI/CD Integration

Add to your `.github/workflows/update-paper.yml`:

```yaml
name: Update Paper Links

on:
  push:
    branches: [main]
  workflow_dispatch:

jobs:
  update:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

      - name: Set up uv
        uses: astral-sh/setup-uv@v5

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Link paper to model
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          uv run scripts/paper_manager.py link \
            --repo-id "${{ github.repository_owner }}/model-name" \
            --repo-type "model" \
            --arxiv-id "2301.12345"
```

## Tips and Best Practices

1. **Always check if paper exists** before indexing to avoid unnecessary operations
2. **Use meaningful commit messages** when linking papers to repositories
3. **Include full citations** in model cards for proper attribution
4. **Link papers to all relevant artifacts** (models, datasets, spaces)
5. **Generate BibTeX citations** for easy reference by others
6. **Keep paper visibility updated** in your HF profile settings
7. **Use templates consistently** within your research group
8. **Version control your papers** alongside code

## Troubleshooting

### Paper not found after indexing

```bash
# Visit the URL directly to trigger indexing
open "https://huggingface.co/papers/2301.12345"

# Wait a few seconds, then check again
uv run scripts/paper_manager.py check --arxiv-id "2301.12345"
```

### Permission denied when linking

```bash
# Verify your token has write access
echo $HF_TOKEN

# Set token if missing
export HF_TOKEN="your_token_here"

# Or use .env file
echo "HF_TOKEN=your_token_here" > .env
```

### arXiv ID format issues

```bash
# The script handles various formats:
uv run scripts/paper_manager.py check --arxiv-id "2301.12345"
uv run scripts/paper_manager.py check --arxiv-id "arxiv:2301.12345"
uv run scripts/paper_manager.py check --arxiv-id "https://arxiv.org/abs/2301.12345"

# All are equivalent and will be normalized
```

## Next Steps

- Explore the [Paper Pages documentation](https://huggingface.co/docs/hub/en/paper-pages)
- Check out [tfrere's research template](https://huggingface.co/spaces/tfrere/research-article-template)
- Browse [papers on HF](https://huggingface.co/papers)
- Learn about [model cards](https://huggingface.co/docs/hub/en/model-cards)
@@ -0,0 +1,216 @@
# Quick Reference Guide

## Essential Commands

### Paper Indexing
```bash
# Index from arXiv
uv run scripts/paper_manager.py index --arxiv-id "2301.12345"

# Check if exists
uv run scripts/paper_manager.py check --arxiv-id "2301.12345"
```

### Linking Papers
```bash
# Link to model
uv run scripts/paper_manager.py link \
  --repo-id "username/model" \
  --repo-type "model" \
  --arxiv-id "2301.12345"

# Link to dataset
uv run scripts/paper_manager.py link \
  --repo-id "username/dataset" \
  --repo-type "dataset" \
  --arxiv-id "2301.12345"

# Link multiple papers
uv run scripts/paper_manager.py link \
  --repo-id "username/model" \
  --repo-type "model" \
  --arxiv-ids "2301.12345,2302.67890"
```

### Creating Papers
```bash
# Standard template
uv run scripts/paper_manager.py create \
  --template "standard" \
  --title "Paper Title" \
  --output "paper.md"

# Modern template
uv run scripts/paper_manager.py create \
  --template "modern" \
  --title "Paper Title" \
  --authors "Author1, Author2" \
  --abstract "Abstract text" \
  --output "paper.md"

# ML Report
uv run scripts/paper_manager.py create \
  --template "ml-report" \
  --title "Experiment Report" \
  --output "report.md"

# arXiv style
uv run scripts/paper_manager.py create \
  --template "arxiv" \
  --title "Paper Title" \
  --output "paper.md"
```

### Citations
```bash
# Generate BibTeX
uv run scripts/paper_manager.py citation \
  --arxiv-id "2301.12345" \
  --format "bibtex"
```

### Paper Info
```bash
# JSON format
uv run scripts/paper_manager.py info \
  --arxiv-id "2301.12345" \
  --format "json"

# Text format
uv run scripts/paper_manager.py info \
  --arxiv-id "2301.12345" \
  --format "text"
```

## URL Formats

### Hugging Face Paper Pages
- View paper: `https://huggingface.co/papers/{arxiv-id}`
- Example: `https://huggingface.co/papers/2301.12345`

### arXiv
- Abstract: `https://arxiv.org/abs/{arxiv-id}`
- PDF: `https://arxiv.org/pdf/{arxiv-id}.pdf`
- Example: `https://arxiv.org/abs/2301.12345`

## YAML Metadata Format

### Model Card
```yaml
---
language:
- en
license: apache-2.0
tags:
- text-generation
- transformers
library_name: transformers
---
```

### Dataset Card
```yaml
---
language:
- en
license: cc-by-4.0
task_categories:
- text-generation
size_categories:
- 10K<n<100K
---
```

## arXiv ID Formats

All these formats work:
- `2301.12345`
- `arxiv:2301.12345`
- `https://arxiv.org/abs/2301.12345`
- `https://arxiv.org/pdf/2301.12345.pdf`
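
A minimal normalization sketch, mirroring what the script's `_clean_arxiv_id` helper does (the function name here is hypothetical, for illustration only), reduces each of these forms to the bare ID:

```python
import re

def normalize_arxiv_id(raw: str) -> str:
    """Reduce 'arxiv:...' prefixes and arxiv.org URLs to a bare arXiv ID."""
    raw = raw.strip()
    # Strip a leading "arxiv:" scheme, case-insensitively
    raw = re.sub(r'(?i)^arxiv:', '', raw)
    # Strip abs/pdf URL prefixes and a trailing ".pdf"
    raw = re.sub(r'^https?://arxiv\.org/(abs|pdf)/', '', raw)
    raw = re.sub(r'\.pdf$', '', raw)
    # Validate the modern ID form, e.g. 2301.12345 or 2301.12345v2
    if not re.match(r'^\d{4}\.\d{4,5}(v\d+)?$', raw):
        raise ValueError(f"Unrecognized arXiv ID: {raw!r}")
    return raw
```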

## Environment Setup

### Set Token
```bash
export HF_TOKEN="your_token"
```

### Or use .env file
```bash
echo "HF_TOKEN=your_token" > .env
```

## Common Workflows

### 1. Index & Link
```bash
uv run scripts/paper_manager.py index --arxiv-id "2301.12345"
uv run scripts/paper_manager.py link --repo-id "user/model" --arxiv-id "2301.12345"
```

### 2. Create & Publish
```bash
uv run scripts/paper_manager.py create --template "modern" --title "Title" --output "paper.md"
# Edit paper.md
# Submit to arXiv → get ID
uv run scripts/paper_manager.py index --arxiv-id "NEW_ID"
uv run scripts/paper_manager.py link --repo-id "user/model" --arxiv-id "NEW_ID"
```

### 3. Batch Link
```bash
for id in "2301.12345" "2302.67890"; do
  uv run scripts/paper_manager.py link --repo-id "user/model" --arxiv-id "$id"
done
```

## Troubleshooting

### Paper not found
Visit `https://huggingface.co/papers/{arxiv-id}` to trigger indexing

### Permission denied
Check `HF_TOKEN` is set and has write access

### arXiv API errors
Wait a moment and retry - arXiv has rate limits
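
One way to script that retry (a hypothetical helper, not part of `paper_manager.py`): back off exponentially between attempts and re-raise after the last one:

```python
import time

def with_retries(fn, attempts=3, base_delay=1.0):
    """Call fn(), retrying with exponential backoff on any exception."""
    for attempt in range(attempts):
        try:
            return fn()
        except Exception:
            if attempt == attempts - 1:
                raise  # out of attempts - surface the original error
            time.sleep(base_delay * (2 ** attempt))
```

For example, wrap the arXiv request as `with_retries(lambda: requests.get(api_url, timeout=10))`.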

## Tips

1. Always check paper exists before linking
2. Use templates for consistency
3. Include full citations in model cards
4. Link papers to all relevant artifacts
5. Keep citations up to date

## Templates Available

- `standard` - Traditional academic paper
- `modern` - Web-friendly format (Distill-style)
- `arxiv` - arXiv journal format
- `ml-report` - ML experiment documentation

## File Locations

- Scripts: `scripts/paper_manager.py`
- Templates: `templates/*.md`
- Examples: `examples/example_usage.md`
- This guide: `references/quick_reference.md`

## Getting Help

```bash
# Command help
uv run scripts/paper_manager.py --help

# Subcommand help
uv run scripts/paper_manager.py link --help
```

## Additional Resources

- [Full documentation](../SKILL.md)
- [Usage examples](../examples/example_usage.md)
- [HF Paper Pages](https://huggingface.co/papers)
- [tfrere's template](https://huggingface.co/spaces/tfrere/research-article-template)
@@ -0,0 +1,606 @@
|
||||
#!/usr/bin/env -S uv run
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "huggingface_hub",
|
||||
# "pyyaml",
|
||||
# "requests",
|
||||
# "python-dotenv",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
Paper Manager for Hugging Face Hub
|
||||
Manages paper indexing, linking, authorship, and article creation.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
from huggingface_hub import HfApi, hf_hub_download, get_token
|
||||
import yaml
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
except ImportError as e:
|
||||
print(f"Error: Missing required dependency: {e}")
|
||||
print("Tip: run this script with `uv run scripts/paper_manager.py ...`.")
|
||||
sys.exit(1)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class PaperManager:
|
||||
"""Manages paper publishing operations on Hugging Face Hub."""
|
||||
|
||||
def __init__(self, hf_token: Optional[str] = None):
|
||||
"""Initialize Paper Manager with HF token."""
|
||||
self.token = hf_token or os.getenv("HF_TOKEN") or get_token()
|
||||
if not self.token:
|
||||
print("Warning: No HF_TOKEN found. Some operations will fail.")
|
||||
self.api = HfApi(token=self.token)
|
||||
|
||||
def index_paper(self, arxiv_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Index a paper on Hugging Face from arXiv.
|
||||
|
||||
Args:
|
||||
arxiv_id: arXiv identifier (e.g., "2301.12345")
|
||||
|
||||
Returns:
|
||||
dict: Status information
|
||||
"""
|
||||
# Clean and validate arXiv ID
|
||||
try:
|
||||
arxiv_id = self._clean_arxiv_id(arxiv_id)
|
||||
except ValueError as e:
|
||||
print(f"Error: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
print(f"Indexing paper {arxiv_id} on Hugging Face...")
|
||||
|
||||
# Check if paper exists
|
||||
paper_url = f"https://huggingface.co/papers/{arxiv_id}"
|
||||
|
||||
try:
|
||||
response = requests.get(paper_url, timeout=10)
|
||||
if response.status_code == 200:
|
||||
print(f"✓ Paper already indexed at {paper_url}")
|
||||
return {"status": "exists", "url": paper_url}
|
||||
else:
|
||||
print(f"Paper not indexed. Visit {paper_url} to trigger indexing.")
|
||||
print("The paper will be automatically indexed when you first visit the URL.")
|
||||
return {"status": "not_indexed", "url": paper_url, "action": "visit_url"}
|
||||
except requests.RequestException as e:
|
||||
print(f"Error checking paper status: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
def check_paper(self, arxiv_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if a paper exists on Hugging Face.
|
||||
|
||||
Args:
|
||||
arxiv_id: arXiv identifier
|
||||
|
||||
Returns:
|
||||
dict: Paper status and metadata
|
||||
"""
|
||||
try:
|
||||
arxiv_id = self._clean_arxiv_id(arxiv_id)
|
||||
except ValueError as e:
|
||||
return {"exists": False, "error": str(e)}
|
||||
paper_url = f"https://huggingface.co/papers/{arxiv_id}"
|
||||
|
||||
try:
|
||||
response = requests.get(paper_url, timeout=10)
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"exists": True,
|
||||
"url": paper_url,
|
||||
"arxiv_id": arxiv_id,
|
||||
"arxiv_url": f"https://arxiv.org/abs/{arxiv_id}"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"exists": False,
|
||||
"arxiv_id": arxiv_id,
|
||||
"index_url": paper_url,
|
||||
"message": f"Visit {paper_url} to index this paper"
|
||||
}
|
||||
except requests.RequestException as e:
|
||||
return {"exists": False, "error": str(e)}
|
||||
|
||||
def link_paper_to_repo(
|
||||
self,
|
||||
repo_id: str,
|
||||
arxiv_id: str,
|
||||
repo_type: str = "model",
|
||||
citation: Optional[str] = None,
|
||||
create_pr: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Link a paper to a model/dataset/space repository.
|
||||
|
||||
Args:
|
||||
repo_id: Repository identifier (e.g., "username/repo-name")
|
||||
arxiv_id: arXiv identifier
|
||||
repo_type: Type of repository ("model", "dataset", or "space")
|
||||
citation: Optional full citation text
|
||||
create_pr: Create a PR instead of direct commit
|
||||
|
||||
Returns:
|
||||
dict: Operation status
|
||||
"""
|
||||
try:
|
||||
arxiv_id = self._clean_arxiv_id(arxiv_id)
|
||||
except ValueError as e:
|
||||
print(f"Error: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
print(f"Linking paper {arxiv_id} to {repo_type} {repo_id}...")
|
||||
|
||||
try:
|
||||
# Download current README
|
||||
readme_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="README.md",
|
||||
repo_type=repo_type,
|
||||
token=self.token
|
||||
)
|
||||
|
||||
with open(readme_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Parse or create YAML frontmatter
|
||||
updated_content = self._add_paper_to_readme(content, arxiv_id, citation)
|
||||
|
||||
# Upload updated README
|
||||
commit_message = f"Add paper reference: arXiv:{arxiv_id}"
|
||||
|
||||
if create_pr:
|
||||
# Create PR (not implemented in basic version)
|
||||
print("PR creation not yet implemented. Committing directly.")
|
||||
|
||||
self.api.upload_file(
|
||||
path_or_fileobj=updated_content.encode('utf-8'),
|
||||
path_in_repo="README.md",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
commit_message=commit_message,
|
||||
token=self.token
|
||||
)
|
||||
|
||||
paper_url = f"https://huggingface.co/papers/{arxiv_id}"
|
||||
repo_url = f"https://huggingface.co/{repo_id}"
|
||||
|
||||
print(f"✓ Successfully linked paper to repository")
|
||||
print(f" Paper: {paper_url}")
|
||||
print(f" Repo: {repo_url}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"paper_url": paper_url,
|
||||
"repo_url": repo_url,
|
||||
"arxiv_id": arxiv_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error linking paper: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
def _add_paper_to_readme(
|
||||
self,
|
||||
content: str,
|
||||
arxiv_id: str,
|
||||
citation: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Add paper reference to README content.
|
||||
|
||||
Args:
|
||||
content: Current README content
|
||||
arxiv_id: arXiv identifier
|
||||
citation: Optional citation text
|
||||
|
||||
Returns:
|
||||
str: Updated README content
|
||||
"""
|
||||
arxiv_url = f"https://arxiv.org/abs/{arxiv_id}"
|
||||
hf_paper_url = f"https://huggingface.co/papers/{arxiv_id}"
|
||||
|
||||
# Check if YAML frontmatter exists
|
||||
yaml_pattern = r'^---\s*\n(.*?)\n---\s*\n'
|
||||
match = re.match(yaml_pattern, content, re.DOTALL)
|
||||
|
||||
if match:
|
||||
# YAML exists, check if paper already referenced
|
||||
if arxiv_id in content:
|
||||
print(f"Paper {arxiv_id} already referenced in README")
|
||||
return content
|
||||
|
||||
# Add to existing content (after YAML)
|
||||
yaml_end = match.end()
|
||||
before = content[:yaml_end]
|
||||
after = content[yaml_end:]
|
||||
else:
|
||||
# No YAML, add minimal frontmatter
|
||||
yaml_content = "---\n---\n\n"
|
||||
before = yaml_content
|
||||
after = content
|
||||
|
||||
# Add paper reference section with boundary markers
|
||||
paper_section = "\n<!-- paper-manager:start -->\n"
|
||||
paper_section += f"## Paper\n\n"
|
||||
paper_section += f"This {'model' if 'model' in content.lower() else 'work'} is based on research presented in:\n\n"
|
||||
paper_section += f"**[View on arXiv]({arxiv_url})** | "
|
||||
paper_section += f"**[View on Hugging Face]({hf_paper_url})**\n\n"
|
||||
|
||||
if citation:
|
||||
safe_citation = self._sanitize_text(citation)
|
||||
paper_section += f"### Citation\n\n```bibtex\n{safe_citation}\n```\n\n"
|
||||
|
||||
paper_section += "<!-- paper-manager:end -->\n"
|
||||
|
||||
# Insert after YAML, before main content
|
||||
updated_content = before + paper_section + after
|
||||
|
||||
return updated_content
|
||||
|
||||
def create_research_article(
|
||||
self,
|
||||
template: str,
|
||||
title: str,
|
||||
output: str,
|
||||
authors: Optional[str] = None,
|
||||
abstract: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a research article from template.
|
||||
|
||||
Args:
|
||||
template: Template name ("standard", "modern", "arxiv", "ml-report")
|
||||
title: Paper title
|
||||
output: Output filename
|
||||
authors: Comma-separated author names
|
||||
abstract: Abstract text
|
||||
|
||||
Returns:
|
||||
dict: Creation status
|
||||
"""
|
||||
print(f"Creating research article with '{template}' template...")
|
||||
|
||||
# Load template
|
||||
template_dir = Path(__file__).parent.parent / "templates"
|
||||
template_file = template_dir / f"{template}.md"
|
||||
|
||||
if not template_file.exists():
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Template '{template}' not found at {template_file}"
|
||||
}
|
||||
|
||||
with open(template_file, 'r', encoding='utf-8') as f:
|
||||
template_content = f.read()
|
||||
|
||||
# Prepare safe values for different contexts
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
safe_title_body = self._sanitize_text(title)
|
||||
authors_val = authors if authors else "Your Name"
|
||||
safe_authors_body = self._sanitize_text(authors_val)
|
||||
abstract_val = abstract if abstract else "Abstract to be written..."
|
||||
safe_abstract_body = self._sanitize_text(abstract_val)
|
||||
|
||||
# Split frontmatter from body for context-aware escaping
|
||||
fm_pattern = r'^(---\s*\n)(.*?\n)(---\s*\n)'
|
||||
fm_match = re.match(fm_pattern, template_content, re.DOTALL)
|
||||
|
||||
if fm_match:
|
||||
fm_open, fm_body, fm_close = fm_match.group(1), fm_match.group(2), fm_match.group(3)
|
||||
body = template_content[fm_match.end():]
|
||||
|
||||
# YAML-escape values in frontmatter
|
||||
fm_body = fm_body.replace("{{TITLE}}", self._escape_yaml_value(title))
|
||||
fm_body = fm_body.replace("{{AUTHORS}}", self._escape_yaml_value(authors_val))
|
||||
fm_body = fm_body.replace("{{DATE}}", date_str)
|
||||
|
||||
# Sanitize values in body
|
||||
body = body.replace("{{TITLE}}", safe_title_body)
|
||||
body = body.replace("{{AUTHORS}}", safe_authors_body)
|
||||
body = body.replace("{{ABSTRACT}}", safe_abstract_body)
|
||||
body = body.replace("{{DATE}}", date_str)
|
||||
|
||||
content = fm_open + fm_body + fm_close + body
|
||||
else:
|
||||
# No frontmatter — sanitize everything
|
||||
content = template_content.replace("{{TITLE}}", safe_title_body)
|
||||
content = content.replace("{{DATE}}", date_str)
|
||||
content = content.replace("{{AUTHORS}}", safe_authors_body)
|
||||
content = content.replace("{{ABSTRACT}}", safe_abstract_body)
|
||||
|
||||
# Write output
|
||||
with open(output, 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
|
||||
print(f"✓ Research article created at {output}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"output": output,
|
||||
"template": template
|
||||
}
|
||||
|
||||
    def get_arxiv_info(self, arxiv_id: str) -> Dict[str, Any]:
        """
        Fetch paper information from the arXiv API.

        Args:
            arxiv_id: arXiv identifier

        Returns:
            dict: Paper metadata
        """
        try:
            arxiv_id = self._clean_arxiv_id(arxiv_id)
        except ValueError as e:
            return {"error": str(e)}
        api_url = f"https://export.arxiv.org/api/query?id_list={arxiv_id}"

        try:
            response = requests.get(api_url, timeout=10)
            response.raise_for_status()

            # Parse the Atom XML response (simplified)
            content = response.text

            # Extract basic info with regex (proper XML parsing would be better).
            # The first <title> in the Atom feed is the feed title, so the paper
            # title is the second match; every <name> is a paper author.
            title_matches = re.findall(r'<title>(.*?)</title>', content, re.DOTALL)
            authors_matches = re.findall(r'<name>(.*?)</name>', content)
            summary_match = re.search(r'<summary>(.*?)</summary>', content, re.DOTALL)

            # Sanitize all text extracted from the external API
            raw_title = title_matches[1].strip() if len(title_matches) > 1 else None
            raw_authors = authors_matches
            raw_abstract = summary_match.group(1).strip() if summary_match else None

            return {
                "arxiv_id": arxiv_id,
                "title": self._sanitize_text(raw_title) if raw_title else None,
                "authors": [self._sanitize_text(a) for a in raw_authors],
                "abstract": self._sanitize_text(raw_abstract) if raw_abstract else None,
                "arxiv_url": f"https://arxiv.org/abs/{arxiv_id}",
                "pdf_url": f"https://arxiv.org/pdf/{arxiv_id}.pdf"
            }
        except Exception as e:
            return {"error": str(e)}

    def generate_citation(
        self,
        arxiv_id: str,
        format: str = "bibtex"
    ) -> str:
        """
        Generate citation for a paper.

        Args:
            arxiv_id: arXiv identifier
            format: Citation format ("bibtex", "apa", "mla")

        Returns:
            str: Formatted citation
        """
        try:
            arxiv_id = self._clean_arxiv_id(arxiv_id)
        except ValueError as e:
            return f"Error: {e}"

        info = self.get_arxiv_info(arxiv_id)

        if "error" in info:
            return f"Error fetching paper info: {info['error']}"

        if format == "bibtex":
            # Generate BibTeX citation
            key = f"arxiv{arxiv_id.replace('.', '_').replace('/', '_')}"
            raw_authors = " and ".join(info.get("authors", ["Unknown"]))
            raw_title = info.get("title", "Untitled")
            # Extract year from the ID (simplified): the first two digits encode
            # the year for both modern (YYMM.NNNNN) and legacy (cat/YYMMNNN) IDs
            year = re.sub(r'\D', '', arxiv_id)[:2]
            year = f"20{year}" if int(year) < 50 else f"19{year}"

            # Escape BibTeX structural characters in untrusted values
            safe_title = raw_title.replace('{', r'\{').replace('}', r'\}')
            safe_authors = raw_authors.replace('{', r'\{').replace('}', r'\}')

            citation = f"""@article{{{key},
  title={{{safe_title}}},
  author={{{safe_authors}}},
  journal={{arXiv preprint arXiv:{arxiv_id}}},
  year={{{year}}}
}}"""
            return citation

        return f"Format '{format}' not yet implemented"

    # Patterns for valid arXiv IDs
    _ARXIV_ID_MODERN = re.compile(r'^\d{4}\.\d{4,5}(v\d+)?$')
    _ARXIV_ID_LEGACY = re.compile(r'^[a-zA-Z\-]+/\d{7}(v\d+)?$')

    @staticmethod
    def _clean_arxiv_id(arxiv_id: str) -> str:
        """Clean, normalize, and validate an arXiv ID.

        Raises:
            ValueError: If the cleaned ID does not match a valid arXiv format.
        """
        # Remove common prefixes and whitespace
        arxiv_id = arxiv_id.strip()
        arxiv_id = re.sub(r'^arxiv:', '', arxiv_id, flags=re.IGNORECASE)
        arxiv_id = re.sub(r'https?://arxiv\.org/(abs|pdf)/', '', arxiv_id)
        arxiv_id = arxiv_id.replace('.pdf', '')

        # Validate format
        if not (PaperManager._ARXIV_ID_MODERN.match(arxiv_id)
                or PaperManager._ARXIV_ID_LEGACY.match(arxiv_id)):
            raise ValueError(
                f"Invalid arXiv ID: {arxiv_id!r}. "
                "Expected format: YYMM.NNNNN[vN] or category/YYMMNNN[vN]"
            )

        return arxiv_id

    @staticmethod
    def _escape_yaml_value(value: str) -> str:
        """Escape a string for safe use as a YAML scalar value.

        Wraps in double quotes and escapes internal quotes and backslashes
        to prevent YAML injection via crafted titles/authors.
        """
        value = value.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{value}"'

    @staticmethod
    def _sanitize_text(text: str) -> str:
        """Sanitize untrusted text for safe inclusion in Markdown/YAML output.

        Normalizes whitespace, strips control characters, and neutralizes
        markdown code-fence breakout and YAML document delimiters.
        """
        # Remove control characters (keep newlines and tabs)
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
        # Normalize whitespace runs (collapse multiple spaces/tabs, preserve single newlines)
        text = re.sub(r'[^\S\n]+', ' ', text)
        text = re.sub(r'\n{3,}', '\n\n', text)
        # Neutralize markdown code fence breakout
        text = text.replace('```', r'\`\`\`')
        # Neutralize YAML document delimiters at line start
        text = re.sub(r'^---', r'\\---', text, flags=re.MULTILINE)
        return text.strip()


def main():
    """Main CLI entry point."""
    parser = argparse.ArgumentParser(
        description="Paper Manager for Hugging Face Hub",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    subparsers = parser.add_subparsers(dest="command", help="Command to execute")

    # Index command
    index_parser = subparsers.add_parser("index", help="Index a paper from arXiv")
    index_parser.add_argument("--arxiv-id", required=True, help="arXiv paper ID")

    # Check command
    check_parser = subparsers.add_parser("check", help="Check if paper exists")
    check_parser.add_argument("--arxiv-id", required=True, help="arXiv paper ID")

    # Link command
    link_parser = subparsers.add_parser("link", help="Link paper to repository")
    link_parser.add_argument("--repo-id", required=True, help="Repository ID")
    link_parser.add_argument("--repo-type", default="model", choices=["model", "dataset", "space"])
    link_parser.add_argument("--arxiv-id", help="Single arXiv ID")
    link_parser.add_argument("--arxiv-ids", help="Comma-separated arXiv IDs")
    link_parser.add_argument("--citation", help="Full citation text")
    link_parser.add_argument("--create-pr", action="store_true", help="Create PR instead of direct commit")

    # Create command
    create_parser = subparsers.add_parser("create", help="Create research article")
    create_parser.add_argument("--template", required=True, help="Template name")
    create_parser.add_argument("--title", required=True, help="Paper title")
    create_parser.add_argument("--output", required=True, help="Output filename")
    create_parser.add_argument("--authors", help="Comma-separated authors")
    create_parser.add_argument("--abstract", help="Abstract text")

    # Info command
    info_parser = subparsers.add_parser("info", help="Get paper information")
    info_parser.add_argument("--arxiv-id", required=True, help="arXiv paper ID")
    info_parser.add_argument("--format", default="json", choices=["json", "text"])

    # Citation command
    citation_parser = subparsers.add_parser("citation", help="Generate citation")
    citation_parser.add_argument("--arxiv-id", required=True, help="arXiv paper ID")
    citation_parser.add_argument("--format", default="bibtex", choices=["bibtex", "apa", "mla"])

    # Search command
    search_parser = subparsers.add_parser("search", help="Search papers")
    search_parser.add_argument("--query", required=True, help="Search query")

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    # Initialize manager
    manager = PaperManager()

    # Execute command
    if args.command == "index":
        result = manager.index_paper(args.arxiv_id)
        print(json.dumps(result, indent=2))

    elif args.command == "check":
        result = manager.check_paper(args.arxiv_id)
        print(json.dumps(result, indent=2))

    elif args.command == "link":
        arxiv_ids = []
        if args.arxiv_id:
            arxiv_ids.append(args.arxiv_id)
        if args.arxiv_ids:
            arxiv_ids.extend([raw_id.strip() for raw_id in args.arxiv_ids.split(",")])

        if not arxiv_ids:
            print("Error: Must provide --arxiv-id or --arxiv-ids")
            sys.exit(1)

        for arxiv_id in arxiv_ids:
            result = manager.link_paper_to_repo(
                repo_id=args.repo_id,
                arxiv_id=arxiv_id,
                repo_type=args.repo_type,
                citation=args.citation,
                create_pr=args.create_pr
            )
            print(json.dumps(result, indent=2))

    elif args.command == "create":
        result = manager.create_research_article(
            template=args.template,
            title=args.title,
            output=args.output,
            authors=args.authors,
            abstract=args.abstract
        )
        print(json.dumps(result, indent=2))

    elif args.command == "info":
        result = manager.get_arxiv_info(args.arxiv_id)
        if args.format == "json":
            print(json.dumps(result, indent=2))
        else:
            if "error" in result:
                print(f"Error: {result['error']}")
            else:
                print(f"Title: {result.get('title')}")
                print(f"Authors: {', '.join(result.get('authors', []))}")
                print(f"arXiv URL: {result.get('arxiv_url')}")
                print(f"\nAbstract:\n{result.get('abstract')}")

    elif args.command == "citation":
        citation = manager.generate_citation(args.arxiv_id, args.format)
        print(citation)

    elif args.command == "search":
        print(f"Searching for: {args.query}")
        print("Search functionality coming soon!")
        print(f"Visit: https://huggingface.co/papers?search={args.query}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,299 @@
---
title: {{TITLE}}
authors: {{AUTHORS}}
date: {{DATE}}
arxiv:
layout: arxiv
---

# {{TITLE}}

<div class="arxiv-header">

**{{AUTHORS}}**

*Submitted to arXiv: {{DATE}}*

</div>

---

**Abstract**—{{ABSTRACT}}

**Index Terms**—Machine Learning, Deep Learning, Neural Networks

---

## I. INTRODUCTION

**T**HIS paper presents [brief overview of the contribution]. The main contributions of this work are:

- Contribution 1: Description
- Contribution 2: Description
- Contribution 3: Description

The rest of this paper is organized as follows: Section II reviews related work, Section III describes the proposed methodology, Section IV presents experimental results, and Section V concludes the paper.

## II. RELATED WORK

### A. Subarea 1

Discussion of relevant prior work in subarea 1.

### B. Subarea 2

Discussion of relevant prior work in subarea 2.

### C. Comparison with Prior Art

Table comparing existing methods:

| Method | Year | Approach | Limitation |
|--------|------|----------|------------|
| Method A [1] | 2020 | Description | Issue |
| Method B [2] | 2021 | Description | Issue |
| Method C [3] | 2023 | Description | Issue |

## III. METHODOLOGY

### A. Problem Formulation

Let $X = \{x_1, x_2, ..., x_n\}$ be the input space and $Y = \{y_1, y_2, ..., y_m\}$ be the output space. We aim to learn a function $f: X \rightarrow Y$ that minimizes:

$$
\mathcal{L}(\theta) = \sum_{i=1}^{N} \ell(f(x_i; \theta), y_i) + \lambda R(\theta)
$$

where $\theta$ represents model parameters, $\ell$ is the loss function, and $R(\theta)$ is a regularization term.

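As a concrete illustration, the objective above can be evaluated directly; this minimal sketch uses a 1-D linear model $f(x; \theta) = \theta x$, squared loss, and an L2 regularizer (the data points and $\lambda$ are placeholder choices):

```python
# Sketch of L(θ) = Σ ℓ(f(x_i; θ), y_i) + λ R(θ) for f(x; θ) = θ·x,
# squared loss ℓ, and L2 regularizer R(θ) = θ². Data and λ are illustrative.

def loss(theta, data, lam=0.1):
    fit = sum((theta * x - y) ** 2 for x, y in data)  # Σ ℓ(f(x_i; θ), y_i)
    reg = lam * theta ** 2                            # λ R(θ)
    return fit + reg

data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9)]
print(loss(2.0, data))
```

Minimizing this trade-off between data fit and the penalty $R(\theta)$ is exactly what the training algorithm below performs.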
### B. Model Architecture

Describe the model architecture in detail.

**Input Layer**: Description

**Hidden Layers**: Let $h^{(l)}$ denote the activation of layer $l$:

$$
h^{(l)} = \sigma(W^{(l)}h^{(l-1)} + b^{(l)})
$$

where $\sigma$ is the activation function, $W^{(l)}$ is the weight matrix, and $b^{(l)}$ is the bias vector.

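The layer equation can be sketched in a few lines of plain Python; here $\sigma = \tanh$ and the 2×2 weights, biases, and input are illustrative placeholders:

```python
import math

# One hidden layer: h = σ(W · h_prev + b), with σ = tanh.
# Weights, biases, and input below are placeholder values.

def layer(W, h_prev, b):
    return [math.tanh(sum(w * h for w, h in zip(row, h_prev)) + bi)
            for row, bi in zip(W, b)]

W = [[0.5, -0.2], [0.1, 0.4]]
b = [0.0, 0.1]
h = layer(W, [1.0, 2.0], b)
print(h)
```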
**Output Layer**: Description

### C. Training Algorithm

**Algorithm 1**: Training Procedure

```
1: Input: Training data D = {(xi, yi)}
2: Initialize parameters θ
3: for epoch = 1 to max_epochs do
4:     for each mini-batch B ⊂ D do
5:         Compute loss: L(θ) = 1/|B| Σ ℓ(f(xi; θ), yi)
6:         Update: θ ← θ - η∇θL(θ)
7:     end for
8: end for
9: Return: Trained parameters θ*
```

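A runnable version of Algorithm 1 on a toy 1-D least-squares problem (the dataset, learning rate, and epoch count are placeholders for illustration):

```python
import random

# Mini-batch SGD for f(x; θ) = θ·x with squared loss, following Algorithm 1.

def train(data, lr=0.05, max_epochs=200, batch_size=2, seed=0):
    rng = random.Random(seed)
    theta = 0.0                                    # 2: initialize parameters θ
    for _ in range(max_epochs):                    # 3: for epoch = 1..max_epochs
        rng.shuffle(data)
        for i in range(0, len(data), batch_size):  # 4: for each mini-batch B
            batch = data[i:i + batch_size]
            # 5-6: gradient of the mean squared loss over B, then SGD update
            grad = sum(2 * (theta * x - y) * x for x, y in batch) / len(batch)
            theta -= lr * grad
    return theta                                   # 9: trained parameters θ*

data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0), (4.0, 12.0)]
print(train(data))  # converges near θ = 3
```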
### D. Complexity Analysis

**Time Complexity**: The training algorithm has time complexity $O(NTE)$ where $N$ is the dataset size, $T$ is the number of epochs, and $E$ is the per-example computation cost.

**Space Complexity**: The model requires $O(P)$ space where $P$ is the number of parameters.

## IV. EXPERIMENTS

### A. Experimental Setup

**Datasets**: We evaluate on the following benchmarks:

1. **Dataset A**: Description (size, splits, characteristics)
2. **Dataset B**: Description
3. **Dataset C**: Description

**Baselines**: We compare against:

- Baseline 1 [4]: Description
- Baseline 2 [5]: Description
- Baseline 3 [6]: Description

**Evaluation Metrics**: Performance is measured using:

- Metric 1: Definition
- Metric 2: Definition
- Metric 3: Definition

**Implementation Details**: All experiments are conducted using:

- Framework: PyTorch 2.0
- Hardware: NVIDIA A100 GPUs
- Hyperparameters: Learning rate $\eta = 10^{-4}$, batch size $B = 32$, epochs $T = 100$

### B. Quantitative Results

**TABLE I: MAIN RESULTS**

| Method | Dataset A | Dataset B | Dataset C | Average |
|--------|-----------|-----------|-----------|---------|
| Baseline 1 [4] | 82.3 | 78.5 | 80.1 | 80.3 |
| Baseline 2 [5] | 85.7 | 82.1 | 83.9 | 83.9 |
| Baseline 3 [6] | 88.1 | 85.3 | 86.7 | 86.7 |
| **Ours** | **91.2** | **88.9** | **90.1** | **90.1** |

Our method achieves state-of-the-art performance across all three benchmarks, with an average improvement of 3.4 percentage points over the previous best method.

### C. Ablation Study

**TABLE II: ABLATION STUDY RESULTS**

| Configuration | Dataset A | Δ |
|---------------|-----------|---|
| Full Model | 91.2 | - |
| w/o Component A | 88.7 | -2.5 |
| w/o Component B | 89.4 | -1.8 |
| w/o Component C | 90.5 | -0.7 |

The ablation study demonstrates that all components contribute to the final performance, with Component A having the largest impact.

### D. Qualitative Analysis

**Fig. 1**: Visualization of learned representations using t-SNE projection.

**Fig. 2**: Example predictions showing correct classifications and failure cases.

### E. Computational Efficiency

**TABLE III: COMPUTATIONAL REQUIREMENTS**

| Method | Parameters | FLOPs | Inference (ms) |
|--------|------------|-------|----------------|
| Baseline 1 [4] | 50M | 10G | 8.2 |
| Baseline 2 [5] | 100M | 25G | 15.7 |
| Baseline 3 [6] | 200M | 50G | 28.3 |
| **Ours** | **80M** | **18G** | **12.1** |

Our method achieves superior performance while maintaining reasonable computational costs.

## V. DISCUSSION

### A. Analysis of Results

The experimental results demonstrate that [analysis].

### B. Limitations

Current limitations include:

1. Limitation 1: Description
2. Limitation 2: Description
3. Limitation 3: Description

### C. Broader Impact

Potential applications include:

- Application 1: Description
- Application 2: Description
- Application 3: Description

**Ethical Considerations**: [Discussion of potential risks and mitigation strategies]

## VI. CONCLUSION

This paper presented {{TITLE}}, which achieves [main achievement]. The key contributions are:

1. Contribution 1: Summary
2. Contribution 2: Summary
3. Contribution 3: Summary

Future work will focus on [future directions].

## ACKNOWLEDGMENTS

The authors thank [acknowledgments]. This work was supported by [funding sources].

## REFERENCES

[1] Author A et al., "Paper Title," *Conference Name*, 2020.

[2] Author B et al., "Paper Title," *Journal Name*, vol. X, no. Y, pp. Z-W, 2021.

[3] Author C et al., "Paper Title," *arXiv preprint arXiv:XXXX.XXXXX*, 2023.

[4] Author D et al., "Baseline 1 Paper," *Conference*, 2019.

[5] Author E et al., "Baseline 2 Paper," *Conference*, 2021.

[6] Author F et al., "Baseline 3 Paper," *Conference*, 2023.

---

## APPENDIX A: ADDITIONAL EXPERIMENTS

Supplementary experimental results.

## APPENDIX B: PROOF OF THEOREM

**Theorem 1**: Statement of theorem.

**Proof**: Detailed proof.

## APPENDIX C: HYPERPARAMETERS

Complete list of hyperparameters used in all experiments:

| Hyperparameter | Value | Description |
|----------------|-------|-------------|
| Learning rate | $10^{-4}$ | Initial learning rate |
| Batch size | 32 | Training batch size |
| Epochs | 100 | Number of training epochs |
| Optimizer | AdamW | Optimization algorithm |
| Weight decay | 0.01 | L2 regularization coefficient |
| Warmup steps | 1000 | LR warmup duration |
| Dropout | 0.1 | Dropout probability |

---

<style>
.arxiv-header {
  text-align: center;
  margin-bottom: 2em;
}

body {
  font-family: 'Computer Modern', serif;
  line-height: 1.6;
}

h1 {
  text-align: center;
  font-size: 1.8em;
  margin-top: 1em;
}

h2 {
  font-size: 1.3em;
  margin-top: 1.5em;
  font-weight: bold;
}

h3 {
  font-size: 1.1em;
  font-style: italic;
  margin-top: 1em;
}

table {
  margin: 1em auto;
  border-collapse: collapse;
}

th, td {
  border: 1px solid #000;
  padding: 0.5em;
  text-align: center;
}
</style>
@@ -0,0 +1,358 @@
---
title: {{TITLE}}
authors: {{AUTHORS}}
date: {{DATE}}
type: ml-experiment-report
tags: [machine-learning, experiment-report]
---

# {{TITLE}}

**Machine Learning Experiment Report**

**Researchers**: {{AUTHORS}}
**Date**: {{DATE}}
**Status**: Draft / Final / In Review

---

## Executive Summary

{{ABSTRACT}}

### Key Findings
- Finding 1
- Finding 2
- Finding 3

### Recommendations
- Recommendation 1
- Recommendation 2

---

## 1. Objective

### 1.1 Research Question

What specific question are we trying to answer?

### 1.2 Success Criteria

How will we measure success?

- **Metric 1**: Target value
- **Metric 2**: Target value
- **Metric 3**: Target value

### 1.3 Constraints

- Computational budget
- Time constraints
- Data availability

---

## 2. Dataset

### 2.1 Data Description

| Property | Value |
|----------|-------|
| **Name** | Dataset name |
| **Source** | Origin of data |
| **Size** | Number of examples |
| **Features** | Feature count and types |
| **Target** | What we're predicting |
| **License** | Usage rights |

### 2.2 Data Splits

| Split | Size | Percentage |
|-------|------|------------|
| Train | X examples | Y% |
| Validation | X examples | Y% |
| Test | X examples | Y% |

### 2.3 Data Quality

- **Missing Values**: Analysis and handling
- **Outliers**: Detection and treatment
- **Imbalance**: Class distribution
- **Preprocessing**: Transformations applied

### 2.4 Exploratory Analysis

Key insights from data exploration:

1. Pattern 1
2. Pattern 2
3. Pattern 3

---

## 3. Model

### 3.1 Architecture

Describe the model architecture:

```
Input → Layer 1 → Layer 2 → ... → Output
```

### 3.2 Model Specifications

| Component | Configuration |
|-----------|--------------|
| **Type** | Model family |
| **Parameters** | Total count |
| **Layers** | Number and types |
| **Activation** | Functions used |
| **Dropout** | Regularization rate |

### 3.3 Baseline Models

What are we comparing against?

1. **Baseline 1**: Simple baseline (e.g., majority class)
2. **Baseline 2**: Standard approach (e.g., logistic regression)
3. **Baseline 3**: Previous best method

---

## 4. Training

### 4.1 Hyperparameters

| Hyperparameter | Value | Rationale |
|----------------|-------|-----------|
| Learning Rate | 1e-4 | Tuned via grid search |
| Batch Size | 32 | GPU memory constraint |
| Epochs | 100 | Based on validation |
| Optimizer | AdamW | Standard for transformers |
| Weight Decay | 0.01 | Regularization |
| LR Schedule | Cosine | Smooth convergence |

### 4.2 Training Process

```python
# Training pseudocode
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)
    if val_loss < best_loss:
        save_checkpoint(model)
```

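The checkpointing logic above can be made concrete with canned loss values standing in for the real `train_one_epoch`/`validate` calls (all numbers are placeholders):

```python
# Best-checkpoint tracking: save only when validation loss improves.
# The loss sequence is a stand-in for real train/validate calls.

def run(val_losses):
    best_loss = float("inf")
    checkpoints = []
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_loss:       # keep only improving epochs
            best_loss = val_loss
            checkpoints.append(epoch)  # stand-in for save_checkpoint(model)
    return best_loss, checkpoints

best, saved = run([0.9, 0.7, 0.8, 0.5, 0.6])
print(best, saved)
```

Note that `best_loss` must be initialized (e.g., to infinity) before the loop, which the pseudocode above leaves implicit.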
### 4.3 Computational Resources

| Resource | Specification |
|----------|--------------|
| **Hardware** | GPU model and count |
| **Memory** | RAM and VRAM |
| **Training Time** | Hours/days |
| **Cost** | Estimated compute cost |

### 4.4 Training Curves

Include plots of:
- Training loss over time
- Validation loss over time
- Learning rate schedule
- Other relevant metrics

---

## 5. Results

### 5.1 Quantitative Results

| Model | Accuracy | Precision | Recall | F1 | AUC |
|-------|----------|-----------|--------|-------|-----|
| Baseline 1 | 0.65 | 0.64 | 0.66 | 0.65 | 0.70 |
| Baseline 2 | 0.78 | 0.77 | 0.79 | 0.78 | 0.82 |
| **Ours** | **0.89** | **0.88** | **0.90** | **0.89** | **0.93** |

### 5.2 Statistical Significance

- **P-value**: Statistical test results
- **Confidence Intervals**: 95% CI for key metrics
- **Multiple Runs**: Mean ± std over N runs

### 5.3 Per-Class Performance

| Class | Precision | Recall | F1 | Support |
|-------|-----------|--------|-----|---------|
| Class 1 | 0.90 | 0.88 | 0.89 | 500 |
| Class 2 | 0.87 | 0.91 | 0.89 | 450 |
| Class 3 | 0.88 | 0.89 | 0.88 | 550 |

### 5.4 Qualitative Results

#### Success Cases

Examples where the model performs well.

#### Failure Cases

Examples where the model fails and why.

---

## 6. Analysis

### 6.1 Ablation Study

| Configuration | Score | Change |
|---------------|-------|--------|
| Full Model | 0.89 | - |
| - Feature Set A | 0.85 | -0.04 |
| - Feature Set B | 0.87 | -0.02 |
| - Augmentation | 0.86 | -0.03 |

### 6.2 Error Analysis

What types of errors is the model making?

1. **Error Type 1**: Frequency and cause
2. **Error Type 2**: Frequency and cause
3. **Error Type 3**: Frequency and cause

### 6.3 Feature Importance

Which features matter most?

| Feature | Importance | Notes |
|---------|------------|-------|
| Feature 1 | 0.35 | Most predictive |
| Feature 2 | 0.28 | Secondary signal |
| Feature 3 | 0.15 | Marginal impact |

---

## 7. Robustness

### 7.1 Cross-Dataset Evaluation

How does the model generalize to other datasets?

| Dataset | Score | Notes |
|---------|-------|-------|
| Original | 0.89 | Training distribution |
| Dataset A | 0.82 | Similar domain |
| Dataset B | 0.71 | Different domain |

### 7.2 Adversarial Robustness

Performance under adversarial conditions.

### 7.3 Fairness Analysis

Performance across demographic groups or sensitive attributes.

---

## 8. Deployment Considerations

### 8.1 Model Size

- **Parameters**: Total count
- **Disk Size**: MB/GB on disk
- **Memory**: Runtime memory usage

### 8.2 Inference Speed

| Batch Size | Latency | Throughput |
|------------|---------|------------|
| 1 | 10ms | 100 QPS |
| 8 | 45ms | 178 QPS |
| 32 | 150ms | 213 QPS |

### 8.3 Production Requirements

- **Dependencies**: Software requirements
- **Infrastructure**: Hardware needs
- **Monitoring**: What to track in production
- **Fallback**: Backup strategy

---

## 9. Conclusions

### 9.1 Summary

Key takeaways from the experiment.

### 9.2 Did We Meet Objectives?

| Objective | Status | Notes |
|-----------|--------|-------|
| Objective 1 | ✅ Met | Achieved target |
| Objective 2 | ⚠️ Partial | Close to target |
| Objective 3 | ❌ Not Met | Needs more work |

### 9.3 Lessons Learned

What did we learn from this experiment?

1. Lesson 1
2. Lesson 2
3. Lesson 3

---

## 10. Next Steps

### 10.1 Short-term (1-2 weeks)

- [ ] Task 1
- [ ] Task 2
- [ ] Task 3

### 10.2 Medium-term (1-2 months)

- [ ] Task 1
- [ ] Task 2
- [ ] Task 3

### 10.3 Long-term (3+ months)

- [ ] Task 1
- [ ] Task 2
- [ ] Task 3

---

## References

1. Reference 1
2. Reference 2
3. Reference 3

---

## Appendix

### A. Hyperparameter Search

Results from hyperparameter tuning.

### B. Additional Experiments

Supplementary experiments not included in main text.

### C. Code

Links to code repositories:
- Training code: [link]
- Evaluation code: [link]
- Model checkpoint: [link]

### D. Data Card

Detailed data documentation following standard practices.

### E. Model Card

Model documentation following responsible AI practices.
@@ -0,0 +1,319 @@
---
title: {{TITLE}}
authors: {{AUTHORS}}
date: {{DATE}}
arxiv:
tags: [machine-learning, ai]
layout: modern
---

<div class="header">

# {{TITLE}}

<div class="authors">
{{AUTHORS}}
</div>

<div class="date">
{{DATE}}
</div>

<div class="links">
[arXiv](#) · [PDF](#) · [Code](#) · [Demo](#)
</div>

</div>

---

## Abstract

<div class="abstract">

{{ABSTRACT}}

</div>

---

## Introduction

Modern research requires clear, accessible communication. This template provides a clean, web-friendly format inspired by Distill and modern scientific publications.

<div class="key-insight">
💡 **Key Insight**: Present your main contribution upfront to engage readers immediately.
</div>

### Why This Matters

Explain the significance of your work in plain language. What real-world problems does it solve?

### Our Approach

Summarize your methodology at a high level before diving into details.

---

## Background

<div class="definition">
**Definition**: Clearly define key terms and concepts early in the paper.
</div>

Provide context necessary to understand your contribution without overwhelming readers with details.

### Problem Statement

Formally state the problem you're addressing.

### Challenges

What makes this problem difficult?

1. **Challenge 1**: Description
2. **Challenge 2**: Description
3. **Challenge 3**: Description

---

## Method

Present your approach with clear visual aids and intuitive explanations.

<div class="figure">

```
[Diagram of your architecture goes here]
```

**Figure 1**: Overview of the proposed method. Caption explains the key components.

</div>

### Model Architecture

Describe your model systematically:

```python
# Pseudocode example
class YourModel:
    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        z = self.encoder(x)
        output = self.decoder(z)
        return output
```

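The pseudocode above can be exercised end to end with toy stand-ins for `Encoder` and `Decoder` (here, simple scaling by reciprocal constants, chosen only to make the data flow runnable):

```python
# Toy stand-ins so the encoder/decoder pipeline can actually run.

class Encoder:
    def __call__(self, x):
        return [v * 0.5 for v in x]  # toy "encoding": scale down

class Decoder:
    def __call__(self, z):
        return [v * 2.0 for v in z]  # toy "decoding": scale back up

class YourModel:
    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

print(YourModel().forward([1.0, 2.0]))  # round-trips the input
```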
### Training Strategy
|
||||
|
||||
Explain how you train the model, including:

- **Objective Function**: Mathematical formulation
- **Optimization**: Algorithm and hyperparameters
- **Regularization**: Techniques to prevent overfitting

---

## Experiments

### Setup

<div class="experiment-details">

| Component | Configuration |
|-----------|--------------|
| **Dataset** | Name, Size, Split |
| **Hardware** | GPU Type, RAM |
| **Framework** | PyTorch 2.0, Transformers |
| **Training Time** | Hours/Days |

</div>

### Results

Present results clearly with tables and visualizations.

<div class="results-table">

| Model | Accuracy | F1 Score | Params | Speed |
|-------|----------|----------|--------|-------|
| Baseline | 85.2% | 0.84 | 100M | 100 tok/s |
| **Ours** | **92.1%** | **0.91** | 120M | 95 tok/s |
| SOTA | 90.5% | 0.89 | 300M | 60 tok/s |

</div>

<div class="insight">
🔍 **Observation**: Our method achieves state-of-the-art performance with fewer parameters.
</div>

### Analysis

Deep dive into what the results reveal:

1. **Performance**: How does your method compare?
2. **Efficiency**: What are the computational costs?
3. **Robustness**: How does it perform across different scenarios?

---

## Ablation Study

Systematically evaluate each component's contribution.

<div class="ablation-results">

| Configuration | Score | Δ |
|---------------|-------|---|
| Full Model | 92.1% | - |
| - Component A | 89.3% | -2.8% |
| - Component B | 90.1% | -2.0% |
| - Component C | 91.5% | -0.6% |

</div>

**Conclusion**: All components contribute meaningfully, with Component A being most critical.

---

## Discussion

### What We Learned

Synthesize insights from your experiments.

### Limitations

<div class="limitations">

⚠️ **Current Limitations**:

1. Performance on domain X is limited
2. Computational requirements are high
3. Requires large training datasets

</div>

### Future Directions

Where should the community go next?

- **Direction 1**: Description
- **Direction 2**: Description
- **Direction 3**: Description

---

## Related Work

Compare and contrast with existing methods.

### Prior Approaches

| Method | Year | Key Idea | Limitation |
|--------|------|----------|------------|
| Method A | 2020 | Approach 1 | Issue X |
| Method B | 2021 | Approach 2 | Issue Y |
| Method C | 2023 | Approach 3 | Issue Z |

### How We Differ

Clearly articulate what's novel about your work.

---

## Conclusion

<div class="conclusion">

We presented **{{TITLE}}**, which achieves:

1. ✅ **Main contribution 1**
2. ✅ **Main contribution 2**
3. ✅ **Main contribution 3**

Our results demonstrate [key finding], opening new directions for [future work].

</div>

---

## Reproducibility

<div class="reproducibility">

### Code & Data

- **Code**: [github.com/username/repo](#)
- **Models**: [huggingface.co/username/model](#)
- **Datasets**: [huggingface.co/datasets/username/dataset](#)
- **Demo**: [huggingface.co/spaces/username/demo](#)

### Citation

```bibtex
@article{yourpaper2025,
  title={{{{TITLE}}}},
  author={{{{AUTHORS}}}},
  year={2025},
  journal={arXiv preprint}
}
```

</div>

---

## Acknowledgments

Thank funding agencies, collaborators, and computing resources that made this work possible.

---

<div class="appendix">

## Appendix

### A. Additional Results

Supplementary experiments and extended results.

### B. Hyperparameters

Complete training configuration:

```yaml
learning_rate: 1e-4
batch_size: 32
epochs: 100
optimizer: AdamW
scheduler: cosine
warmup_steps: 1000
```

### C. Dataset Details

Detailed information about datasets used.

</div>

---

<style>
.header { text-align: center; margin-bottom: 2em; }
.authors { font-size: 1.2em; margin: 0.5em 0; }
.date { color: #666; margin: 0.5em 0; }
.links { margin-top: 1em; }
.abstract { background: #f5f5f5; padding: 1.5em; border-radius: 8px; margin: 1em 0; }
.key-insight, .insight { background: #e8f4f8; border-left: 4px solid #2196F3; padding: 1em; margin: 1em 0; }
.definition { background: #fff3e0; border-left: 4px solid #ff9800; padding: 1em; margin: 1em 0; }
.limitations { background: #ffebee; border-left: 4px solid #f44336; padding: 1em; margin: 1em 0; }
.conclusion { background: #e8f5e9; border-left: 4px solid #4caf50; padding: 1.5em; margin: 1em 0; }
.figure { text-align: center; margin: 2em 0; }
.experiment-details, .results-table, .ablation-results { margin: 1em 0; }
.reproducibility { background: #f5f5f5; padding: 1.5em; border-radius: 8px; margin: 2em 0; }
</style>
@@ -0,0 +1,201 @@
---
title: {{TITLE}}
authors: {{AUTHORS}}
date: {{DATE}}
arxiv:
tags: [machine-learning, deep-learning]
---

# {{TITLE}}

**{{AUTHORS}}**

*{{DATE}}*

---

## Abstract

{{ABSTRACT}}

---

## 1. Introduction

Provide background and motivation for your research. Explain:
- What problem are you addressing?
- Why is it important?
- What is novel about your approach?

### 1.1 Motivation

Describe the real-world context and importance of the problem.

### 1.2 Contributions

List the main contributions of your work:
1. First contribution
2. Second contribution
3. Third contribution

---

## 2. Related Work

Survey previous research relevant to your work. Organize by:
- Different approaches to the problem
- Complementary methods
- Alternative solutions

### 2.1 Previous Approaches

Discuss earlier methods and their limitations.

### 2.2 Recent Advances

Highlight recent developments in the field.

---

## 3. Background

Provide necessary technical background for understanding your work.

### 3.1 Problem Formulation

Formally define the problem you're solving.

### 3.2 Preliminaries

Introduce key concepts, notation, and terminology.

---

## 4. Methodology

Describe your approach in detail.

### 4.1 Overview

Provide a high-level description of your method.

### 4.2 Model Architecture

Detail the technical components of your system.

### 4.3 Training Procedure

Explain how the model is trained.

### 4.4 Implementation Details

Provide reproducibility information:
- Hyperparameters
- Hardware requirements
- Software dependencies

---

## 5. Experiments

Present your experimental setup and results.

### 5.1 Datasets

Describe the datasets used for evaluation.

### 5.2 Evaluation Metrics

Define the metrics used to assess performance.

### 5.3 Baselines

List comparison methods.

### 5.4 Experimental Setup

Detail the experimental configuration.

---

## 6. Results

Present and analyze your findings.

### 6.1 Main Results

Report primary experimental results.

| Model | Dataset | Metric | Score |
|-------|---------|--------|-------|
| Baseline | Dataset A | Accuracy | 0.85 |
| Ours | Dataset A | Accuracy | 0.92 |

### 6.2 Ablation Studies

Analyze the contribution of different components.

### 6.3 Qualitative Analysis

Provide examples and case studies.

---

## 7. Discussion

Interpret your results and discuss implications.

### 7.1 Analysis

What do the results tell us?

### 7.2 Limitations

Acknowledge limitations of your approach.

### 7.3 Broader Impact

Discuss societal implications and potential applications.

---

## 8. Conclusion

Summarize your work and contributions.

### 8.1 Summary

Recap the main findings.

### 8.2 Future Work

Suggest directions for future research.

---

## Acknowledgments

Thank collaborators, funding sources, and computational resources.

---

## References

1. Author A, et al. "Paper Title." Conference/Journal, Year.
2. Author B, et al. "Another Paper." Conference/Journal, Year.

---

## Appendix

### A. Additional Experiments

Supplementary experimental results.

### B. Implementation Details

Code snippets and configuration details.

### C. Hyperparameters

Complete list of hyperparameters used.
@@ -0,0 +1,241 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-papers"
name: hugging-face-papers
description: Read and analyze Hugging Face paper pages or arXiv papers with markdown and papers API metadata.
risk: unknown
---

# Hugging Face Paper Pages

Hugging Face Paper Pages (hf.co/papers) is a platform built on top of arXiv (arxiv.org), specifically for research papers in artificial intelligence (AI) and computer science. Hugging Face users can submit their paper at hf.co/papers/submit, which features it on the Daily Papers feed (hf.co/papers), where users can upvote and comment on papers each day. Each paper page allows authors to:
- claim their paper (by clicking their name in the `authors` field). This makes the paper page appear on their Hugging Face profile.
- link the associated model checkpoints, datasets, and Spaces by including the HF paper or arXiv URL in the model card, dataset card, or README of the Space
- link the GitHub repository and/or project page URLs
- link the HF organization. This also makes the paper page appear on the Hugging Face organization page.

Whenever someone mentions an HF paper or arXiv abstract/PDF URL in a model card, dataset card, or README of a Space repository, the paper is automatically indexed. Note that not all papers indexed on Hugging Face are also submitted to Daily Papers; the latter is more a means of promoting a research paper. Papers can only be submitted to Daily Papers up to 14 days after their publication date on arXiv.

The Hugging Face team has built an easy-to-use API for interacting with paper pages. The content of a paper can be fetched as markdown, or structured metadata can be returned, such as author names, linked models/datasets/Spaces, and the linked GitHub repo and project page.

## When to Use

- User shares a Hugging Face paper page URL (e.g. `https://huggingface.co/papers/2602.08025`)
- User shares a Hugging Face markdown paper page URL (e.g. `https://huggingface.co/papers/2602.08025.md`)
- User shares an arXiv URL (e.g. `https://arxiv.org/abs/2602.08025` or `https://arxiv.org/pdf/2602.08025`)
- User mentions an arXiv ID (e.g. `2602.08025`)
- User asks you to summarize, explain, or analyze an AI research paper

## Parsing the paper ID

It's recommended to parse the paper ID (arXiv ID) from whatever the user provides:

| Input | Paper ID |
| --- | --- |
| `https://huggingface.co/papers/2602.08025` | `2602.08025` |
| `https://huggingface.co/papers/2602.08025.md` | `2602.08025` |
| `https://arxiv.org/abs/2602.08025` | `2602.08025` |
| `https://arxiv.org/pdf/2602.08025` | `2602.08025` |
| `2602.08025v1` | `2602.08025v1` |
| `2602.08025` | `2602.08025` |

This allows you to pass the paper ID into any of the Hub API endpoints mentioned below.

### Fetch the paper page as markdown

The content of a paper can be fetched as markdown like so:

```bash
curl -s "https://huggingface.co/papers/{PAPER_ID}.md"
```

This should return the Hugging Face paper page as markdown; it relies on the HTML version of the paper at https://arxiv.org/html/{PAPER_ID}.

There are two exceptions:
- Not all arXiv papers have an HTML version. If the HTML version of the paper does not exist, the content falls back to the HTML of the Hugging Face paper page.
- If the request returns a 404, the paper is not yet indexed on hf.co/papers. See [Error handling](#error-handling) for details.

Alternatively, you can request markdown from the normal paper page URL, like so:

```bash
curl -s -H "Accept: text/markdown" "https://huggingface.co/papers/{PAPER_ID}"
```

### Paper Pages API Endpoints

All endpoints use the base URL `https://huggingface.co`.

#### Get structured metadata

Fetch the paper metadata as JSON using the Hugging Face REST API:

```bash
curl -s "https://huggingface.co/api/papers/{PAPER_ID}"
```

This returns structured metadata that can include:

- authors (names and Hugging Face usernames, in case they have claimed the paper)
- media URLs (uploaded when submitting the paper to Daily Papers)
- summary (abstract) and AI-generated summary
- project page and GitHub repository
- organization and engagement metadata (number of upvotes)

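The same endpoint can be called from Python using only the standard library. This is a minimal sketch; the `paper_api_url` helper is illustrative, not an official client:

```python
import json
import urllib.request

def paper_api_url(paper_id: str) -> str:
    """Build the structured-metadata endpoint URL for a paper ID."""
    return f"https://huggingface.co/api/papers/{paper_id}"

def fetch_paper_metadata(paper_id: str) -> dict:
    """Fetch paper metadata as JSON; raises urllib.error.HTTPError on a 404."""
    with urllib.request.urlopen(paper_api_url(paper_id)) as resp:
        return json.load(resp)
```

A 404 here has the same meaning as for the markdown endpoint: the paper is not yet indexed.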
To find models linked to the paper, use:

```bash
curl -s "https://huggingface.co/api/models?filter=arxiv:{PAPER_ID}"
```

To find datasets linked to the paper, use:

```bash
curl -s "https://huggingface.co/api/datasets?filter=arxiv:{PAPER_ID}"
```

To find Spaces linked to the paper, use:

```bash
curl -s "https://huggingface.co/api/spaces?filter=arxiv:{PAPER_ID}"
```

#### Claim paper authorship

Claim authorship of a paper for a Hugging Face user:

```bash
curl "https://huggingface.co/api/settings/papers/claim" \
  --request POST \
  --header "Content-Type: application/json" \
  --header "Authorization: Bearer $HF_TOKEN" \
  --data '{
    "paperId": "{PAPER_ID}",
    "claimAuthorId": "{AUTHOR_ENTRY_ID}",
    "targetUserId": "{USER_ID}"
  }'
```

- Endpoint: `POST /api/settings/papers/claim`
- Body:
  - `paperId` (string, required): arXiv paper identifier being claimed
  - `claimAuthorId` (string): author entry on the paper being claimed, 24-char hex ID
  - `targetUserId` (string): HF user who should receive the claim, 24-char hex ID
- Response: paper authorship claim result, including the claimed paper ID

#### Get daily papers

Fetch the Daily Papers feed:

```bash
curl -s -H "Authorization: Bearer $HF_TOKEN" \
  "https://huggingface.co/api/daily_papers?p=0&limit=20&date=2017-07-21&sort=publishedAt"
```

- Endpoint: `GET /api/daily_papers`
- Query parameters:
  - `p` (integer): page number
  - `limit` (integer): number of results, between 1 and 100
  - `date` (string): RFC 3339 full-date, for example `2017-07-21`
  - `week` (string): ISO week, for example `2024-W03`
  - `month` (string): month value, for example `2024-01`
  - `submitter` (string): filter by submitter
  - `sort` (enum): `publishedAt` or `trending`
- Response: list of daily papers

#### List papers

List arXiv papers sorted by published date:

```bash
curl -s -H "Authorization: Bearer $HF_TOKEN" \
  "https://huggingface.co/api/papers?cursor={CURSOR}&limit=20"
```

- Endpoint: `GET /api/papers`
- Query parameters:
  - `cursor` (string): pagination cursor
  - `limit` (integer): number of results, between 1 and 100
- Response: list of papers

#### Search papers

Perform hybrid semantic and full-text search on papers:

```bash
curl -s -H "Authorization: Bearer $HF_TOKEN" \
  "https://huggingface.co/api/papers/search?q=vision+language&limit=20"
```

This searches over the paper title, authors, and content.

- Endpoint: `GET /api/papers/search`
- Query parameters:
  - `q` (string): search query, max length 250
  - `limit` (integer): number of results, between 1 and 120
- Response: matching papers

#### Index a paper

Insert a paper from arXiv by ID. If the paper is already indexed, only its authors can re-index it:

```bash
curl "https://huggingface.co/api/papers/index" \
  --request POST \
  --header "Content-Type: application/json" \
  --header "Authorization: Bearer $HF_TOKEN" \
  --data '{
    "arxivId": "{ARXIV_ID}"
  }'
```

- Endpoint: `POST /api/papers/index`
- Body:
  - `arxivId` (string, required): arXiv ID to index, for example `2301.00001`
    - Pattern: `^\d{4}\.\d{4,5}$`
- Response: empty JSON object on success

#### Update paper links

Update the project page, GitHub repository, or submitting organization for a paper. The requester must be the paper author, the Daily Papers submitter, or a papers admin:

```bash
curl "https://huggingface.co/api/papers/{PAPER_OBJECT_ID}/links" \
  --request POST \
  --header "Content-Type: application/json" \
  --header "Authorization: Bearer $HF_TOKEN" \
  --data '{
    "projectPage": "https://example.com",
    "githubRepo": "https://github.com/org/repo",
    "organizationId": "{ORGANIZATION_ID}"
  }'
```

- Endpoint: `POST /api/papers/{paperId}/links`
- Path parameters:
  - `paperId` (string, required): Hugging Face paper object ID
- Body:
  - `githubRepo` (string, nullable): GitHub repository URL
  - `organizationId` (string, nullable): organization ID, 24-char hex ID
  - `projectPage` (string, nullable): project page URL
- Response: empty JSON object on success

## Error Handling

- **404 on `https://huggingface.co/papers/{PAPER_ID}` or the `.md` endpoint**: the paper is not indexed on Hugging Face paper pages yet.
- **404 on `/api/papers/{PAPER_ID}`**: the paper may not be indexed on Hugging Face paper pages yet.
- **Paper ID not found**: verify the extracted arXiv ID, including any version suffix.

### Fallbacks

If the Hugging Face paper page does not contain enough detail for the user's question:

- Check the regular paper page at `https://huggingface.co/papers/{PAPER_ID}`
- Fall back to the arXiv page or PDF for the original source:
  - `https://arxiv.org/abs/{PAPER_ID}`
  - `https://arxiv.org/pdf/{PAPER_ID}`

## Notes

- No authentication is required for public paper pages.
- Write endpoints such as claim authorship, index paper, and update paper links require `Authorization: Bearer $HF_TOKEN`.
- Prefer the `.md` endpoint for reliable machine-readable output.
- Prefer `/api/papers/{PAPER_ID}` when you need structured JSON fields instead of page markdown.
@@ -0,0 +1,117 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-trackio"
name: hugging-face-trackio
description: Track ML experiments with Trackio using Python logging, alerts, and CLI metric retrieval.
risk: unknown
---

# Trackio - Experiment Tracking for ML Training

Trackio is an experiment tracking library for logging and visualizing ML training metrics. It syncs to Hugging Face Spaces for real-time monitoring dashboards.

## Three Interfaces

| Task | Interface | Reference |
|------|-----------|-----------|
| **Logging metrics** during training | Python API | [references/logging_metrics.md](references/logging_metrics.md) |
| **Firing alerts** for training diagnostics | Python API | [references/alerts.md](references/alerts.md) |
| **Retrieving metrics & alerts** after/during training | CLI | [references/retrieving_metrics.md](references/retrieving_metrics.md) |

## When to Use Each

### Python API → Logging

Use `import trackio` in your training scripts to log metrics:

- Initialize tracking with `trackio.init()`
- Log metrics with `trackio.log()` or use TRL's `report_to="trackio"`
- Finalize with `trackio.finish()`

**Key concept**: For remote/cloud training, pass `space_id` — metrics sync to a Space dashboard so they persist after the instance terminates.

→ See [references/logging_metrics.md](references/logging_metrics.md) for setup, TRL integration, and configuration options.

### Python API → Alerts

Insert `trackio.alert()` calls in training code to flag important events — like inserting print statements for debugging, but structured and queryable:

- `trackio.alert(title="...", level=trackio.AlertLevel.WARN)` — fire an alert
- Three severity levels: `INFO`, `WARN`, `ERROR`
- Alerts are printed to terminal, stored in the database, shown in the dashboard, and optionally sent to webhooks (Slack/Discord)

**Key concept for LLM agents**: Alerts are the primary mechanism for autonomous experiment iteration. An agent should insert alerts into training code for diagnostic conditions (loss spikes, NaN gradients, low accuracy, training stalls). Since alerts are printed to the terminal, an agent that is watching the training script's output will see them automatically. For background or detached runs, the agent can poll via CLI instead.

→ See [references/alerts.md](references/alerts.md) for the full alerts API, webhook setup, and autonomous agent workflows.

### CLI → Retrieving

Use the `trackio` command to query logged metrics and alerts:

- `trackio list projects/runs/metrics` — discover what's available
- `trackio get project/run/metric` — retrieve summaries and values
- `trackio list alerts --project <name> --json` — retrieve alerts
- `trackio show` — launch the dashboard
- `trackio sync` — sync to HF Space

**Key concept**: Add `--json` for programmatic output suitable for automation and LLM agents.

→ See [references/retrieving_metrics.md](references/retrieving_metrics.md) for all commands, workflows, and JSON output formats.

## Minimal Logging Setup

```python
import trackio

trackio.init(project="my-project", space_id="username/trackio")
trackio.log({"loss": 0.1, "accuracy": 0.9})
trackio.log({"loss": 0.09, "accuracy": 0.91})
trackio.finish()
```

### Minimal Retrieval

```bash
trackio list projects --json
trackio get metric --project my-project --run my-run --metric loss --json
```

## Autonomous ML Experiment Workflow

When running experiments autonomously as an LLM agent, the recommended workflow is:

1. **Set up training with alerts** — insert `trackio.alert()` calls for diagnostic conditions
2. **Launch training** — run the script in the background
3. **Poll for alerts** — use `trackio list alerts --project <name> --json --since <timestamp>` to check for new alerts
4. **Read metrics** — use `trackio get metric ...` to inspect specific values
5. **Iterate** — based on alerts and metrics, stop the run, adjust hyperparameters, and launch a new run

```python
import trackio

trackio.init(project="my-project", config={"lr": 1e-4})

for step in range(num_steps):
    loss = train_step()
    trackio.log({"loss": loss, "step": step})

    if step > 100 and loss > 5.0:
        trackio.alert(
            title="Loss divergence",
            text=f"Loss {loss:.4f} still high after {step} steps",
            level=trackio.AlertLevel.ERROR,
        )
    if step > 0 and abs(loss) < 1e-8:
        trackio.alert(
            title="Vanishing loss",
            text="Loss near zero — possible gradient collapse",
            level=trackio.AlertLevel.WARN,
        )

trackio.finish()
```

Then poll from a separate terminal/process:

```bash
trackio list alerts --project my-project --json --since "2025-01-01T00:00:00"
```
@@ -0,0 +1,196 @@
# Trackio Alerts

Alerts let you flag important training events directly from code. They are the primary mechanism for LLM agents to diagnose runs and iterate autonomously on ML experiments.

Alerts are printed to the terminal, stored in the database, displayed in the dashboard, and optionally sent to webhooks (Slack/Discord).

## Core API

### trackio.alert()

```python
trackio.alert(
    title="Loss divergence",                     # Short title (required)
    text="Loss 5.2 still high after 200 steps",  # Detailed description (optional)
    level=trackio.AlertLevel.WARN,               # INFO, WARN, or ERROR (default: WARN)
    webhook_url="https://hooks.slack.com/...",   # Per-alert webhook override (optional)
)
```

### Alert Levels

| Level | Usage |
|-------|-------|
| `trackio.AlertLevel.INFO` | Informational milestones (checkpoints saved, eval completed) |
| `trackio.AlertLevel.WARN` | Potential issues (loss plateau, low accuracy, high gradient norm) |
| `trackio.AlertLevel.ERROR` | Critical failures (NaN loss, divergence, OOM) |

### Webhook Support

Set a global webhook URL via `trackio.init()` or the `TRACKIO_WEBHOOK_URL` environment variable. Alerts are auto-formatted for Slack and Discord URLs.

```python
trackio.init(
    project="my-project",
    webhook_url="https://hooks.slack.com/services/...",
    webhook_min_level=trackio.AlertLevel.WARN,  # Only send WARN+ to webhook
)
```

Per-alert override:

```python
trackio.alert(
    title="Critical failure",
    level=trackio.AlertLevel.ERROR,
    webhook_url="https://hooks.slack.com/services/...",  # Overrides global URL
)
```

Environment variables:
- `TRACKIO_WEBHOOK_URL` — global webhook URL
- `TRACKIO_WEBHOOK_MIN_LEVEL` — minimum level for webhook delivery (`info`, `warn`, `error`)

## Retrieving Alerts (CLI)

```bash
# List all alerts for a project
trackio list alerts --project my-project --json

# Filter by run or level
trackio list alerts --project my-project --run my-run --level error --json

# Poll for new alerts since a timestamp (efficient for agents)
trackio list alerts --project my-project --json --since "2025-06-01T12:00:00"
```

### JSON Output Structure

```json
{
  "project": "my-project",
  "run": null,
  "level": null,
  "since": "2025-06-01T12:00:00",
  "alerts": [
    {
      "run": "run-name",
      "title": "Loss divergence",
      "text": "Loss 5.2 still high after 200 steps",
      "level": "warn",
      "step": 200,
      "timestamp": "2025-06-01T12:05:30"
    }
  ]
}
```

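An agent polling in a loop can consume this JSON from Python by shelling out to the CLI. A sketch, assuming the `trackio` CLI is on `PATH`; `fetch_alerts` and `errors_only` are illustrative helper names:

```python
import json
import subprocess

def fetch_alerts(project: str, since: str) -> list[dict]:
    """Run the trackio CLI and return the parsed `alerts` array."""
    result = subprocess.run(
        ["trackio", "list", "alerts", "--project", project,
         "--json", "--since", since],
        capture_output=True, text=True, check=True,
    )
    return json.loads(result.stdout).get("alerts", [])

def errors_only(alerts: list[dict]) -> list[dict]:
    """Keep only ERROR-level alerts — the ones that should stop a run."""
    return [a for a in alerts if a.get("level") == "error"]
```

Passing the timestamp of the last poll as `since` keeps each call cheap and returns only alerts the agent has not yet seen.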
## Autonomous Agent Workflow

The recommended pattern for an LLM agent running ML experiments:

### 1. Insert Alerts Into Training Code

Add diagnostic `trackio.alert()` calls for conditions the agent should react to:

```python
import math

import trackio

trackio.init(project="hyperparam-sweep", config={"lr": lr, "batch_size": bs})

for step in range(num_steps):
    loss = train_step()
    trackio.log({"loss": loss, "step": step})

    if step > 200 and loss > 5.0:
        trackio.alert(
            title="Loss divergence",
            text=f"Loss {loss:.4f} still above 5.0 after {step} steps — learning rate may be too high",
            level=trackio.AlertLevel.ERROR,
        )

    if step > 500 and loss_delta < 0.001:
        trackio.alert(
            title="Training stall",
            text=f"Loss barely changed over last 100 steps (delta={loss_delta:.6f})",
            level=trackio.AlertLevel.WARN,
        )

    if math.isnan(loss):
        trackio.alert(
            title="NaN loss",
            text="Loss became NaN — training is broken",
            level=trackio.AlertLevel.ERROR,
        )
        break

trackio.finish()
```

### 2. Monitor Alerts

Alerts are automatically printed to the terminal when fired. If the agent is watching the training script's output (e.g. running in the foreground or tailing logs), it will see alerts immediately — no polling needed.

For background or detached runs, poll for alerts via CLI:

```bash
# Poll for alerts (run periodically)
trackio list alerts --project hyperparam-sweep --json --since "2025-06-01T00:00:00"
```

### 3. Inspect Metrics Around the Alert

When an alert fires, use `trackio get snapshot` to see all metrics at that point:

```bash
# Alert fired at step 200 — get all metrics in a ±5 step window
trackio get snapshot --project hyperparam-sweep --run run-1 --around 200 --window 5 --json

# Or inspect a single metric around the alert's timestamp
trackio get metric --project hyperparam-sweep --run run-1 --metric loss --around 200 --window 10 --json
```

### 4. React and Iterate

Based on alerts:
- **ERROR alerts** → stop the run, adjust hyperparameters, relaunch
- **WARN alerts** → inspect metrics with `trackio get snapshot ...`, decide whether to intervene
- **INFO alerts** → note progress, continue monitoring

### 5. Compare Across Runs

```bash
# Check metrics from previous runs
trackio get run --project hyperparam-sweep --run run-1 --json
trackio get metric --project hyperparam-sweep --run run-1 --metric loss --json

# Launch new run with adjusted config
python train.py --lr 5e-5
```

## Using Alerts with Transformers / TRL

When using `report_to="trackio"`, you don't control the training loop directly. Use a `TrainerCallback` to fire alerts:

```python
import trackio
from transformers import TrainerCallback
from trl import SFTConfig, SFTTrainer

class AlertCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if "trackio" not in args.report_to:
            return
        if logs and "loss" in logs:
            if logs["loss"] > 5.0 and state.global_step > 100:
                trackio.alert(
                    title="High loss",
                    text=f"Loss {logs['loss']:.4f} at step {state.global_step}",
                    level=trackio.AlertLevel.ERROR,
                )

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="./out", report_to="trackio"),
    callbacks=[AlertCallback()],
    ...
)
```
@@ -0,0 +1,206 @@
# Logging Metrics with Trackio

**Trackio** is a lightweight, free experiment tracking library from Hugging Face. It provides a wandb-compatible API for logging metrics with a local-first design.

- **GitHub**: [gradio-app/trackio](https://github.com/gradio-app/trackio)
- **Docs**: [huggingface.co/docs/trackio](https://huggingface.co/docs/trackio/index)

## Installation

```bash
pip install trackio
# or
uv pip install trackio
```

## Core API

### Basic Usage

```python
import trackio

# Initialize a run
trackio.init(
    project="my-project",
    config={"learning_rate": 0.001, "epochs": 10}
)

# Log metrics during training
for epoch in range(10):
    loss = train_epoch()
    trackio.log({"loss": loss, "epoch": epoch})

# Finalize the run
trackio.finish()
```

### Key Functions

| Function | Purpose |
|----------|---------|
| `trackio.init(...)` | Start a new tracking run |
| `trackio.log(dict)` | Log metrics (called repeatedly during training) |
| `trackio.finish()` | Finalize run and ensure all metrics are saved |
| `trackio.show()` | Launch the local dashboard |
| `trackio.sync(...)` | Sync local project to HF Space |

## trackio.init() Parameters

```python
trackio.init(
    project="my-project",         # Project name (groups runs together)
    name="run-name",              # Optional: name for this specific run
    config={...},                 # Hyperparameters and config to log
    space_id="username/trackio",  # Optional: sync to HF Space for remote dashboard
    group="experiment-group",     # Optional: group related runs
)
```
## Local vs Remote Dashboard

### Local (Default)

By default, trackio stores metrics in a local SQLite database and runs the dashboard locally:

```python
trackio.init(project="my-project")
# ... training ...
trackio.finish()

# Launch local dashboard
trackio.show()
```

Or from the terminal:

```bash
trackio show --project my-project
```

### Remote (HF Space)

Pass `space_id` to sync metrics to a Hugging Face Space for persistent, shareable dashboards:

```python
trackio.init(
    project="my-project",
    space_id="username/trackio"  # Auto-creates Space if it doesn't exist
)
```

⚠️ **For remote training** (cloud GPUs, HF Jobs, etc.): Always use `space_id`, since local storage is lost when the instance terminates.

### Sync Local to Remote

Sync existing local projects to a Space:

```python
trackio.sync(project="my-project", space_id="username/my-experiments")
```

## wandb Compatibility

Trackio is API-compatible with wandb and works as a drop-in replacement:

```python
import trackio as wandb

wandb.init(project="my-project")
wandb.log({"loss": 0.5})
wandb.finish()
```
## TRL Integration

When using TRL trainers, set `report_to="trackio"` for automatic metric logging:

```python
from trl import SFTConfig, SFTTrainer
import trackio

trackio.init(
    project="sft-training",
    space_id="username/trackio",
    config={"model": "Qwen/Qwen2.5-0.5B", "dataset": "trl-lib/Capybara"}
)

config = SFTConfig(
    output_dir="./output",
    report_to="trackio",  # Automatic metric logging
    # ... other config
)

trainer = SFTTrainer(model=model, args=config, ...)
trainer.train()
trackio.finish()
```

## What Gets Logged

With TRL/Transformers integration, trackio automatically captures:

- Training loss
- Learning rate
- Eval metrics
- Training throughput

For manual logging, log any numeric metrics:

```python
trackio.log({
    "train_loss": 0.5,
    "train_accuracy": 0.85,
    "val_loss": 0.4,
    "val_accuracy": 0.88,
    "epoch": 1
})
```

## Grouping Runs

Use `group` to organize related experiments in the dashboard sidebar:

```python
# Group by experiment type
trackio.init(project="my-project", name="baseline-v1", group="baseline")
trackio.init(project="my-project", name="augmented-v1", group="augmented")

# Group by hyperparameter
trackio.init(project="hyperparam-sweep", name="lr-0.001", group="lr_0.001")
trackio.init(project="hyperparam-sweep", name="lr-0.01", group="lr_0.01")
```
## Configuration Best Practices

Keep config minimal — only log what's useful for comparing runs:

```python
trackio.init(
    project="qwen-sft-capybara",
    name="baseline-lr2e5",
    config={
        "model": "Qwen/Qwen2.5-0.5B",
        "dataset": "trl-lib/Capybara",
        "learning_rate": 2e-5,
        "num_epochs": 3,
        "batch_size": 8,
    }
)
```

## Embedding Dashboards

Embed Space dashboards in websites with query parameters:

```html
<iframe
  src="https://username-trackio.hf.space/?project=my-project&metrics=train_loss,val_loss&sidebar=hidden"
  style="width:1600px; height:500px; border:0;">
</iframe>
```

Query parameters:

- `project`: Filter to a specific project
- `metrics`: Comma-separated metric names to show
- `sidebar`: `hidden` or `collapsed`
- `smoothing`: 0-20 (smoothing slider value)
- `xmin`, `xmax`: X-axis limits
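Composing the embed URL can be done with the standard library. A sketch, reusing the hypothetical Space host from the iframe example above (`embed_url` is an illustrative helper, not part of trackio):

```python
from urllib.parse import urlencode

# Illustrative helper: build a dashboard embed URL from the query
# parameters listed above. Substitute your own Space hostname.
def embed_url(host: str, **params: str) -> str:
    return f"https://{host}/?{urlencode(params)}"

url = embed_url(
    "username-trackio.hf.space",
    project="my-project",
    metrics="train_loss,val_loss",
    sidebar="hidden",
)
print(url)
```

Note that `urlencode` percent-encodes the comma in the metric list, which the Space decodes on its end.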
@@ -0,0 +1,251 @@
# Retrieving Metrics with Trackio CLI

The `trackio` CLI provides direct terminal access for querying Trackio experiment tracking data locally, without needing to start the MCP server.

## Quick Command Reference

| Task | Command |
|------|---------|
| List projects | `trackio list projects` |
| List runs | `trackio list runs --project <name>` |
| List metrics | `trackio list metrics --project <name> --run <name>` |
| List system metrics | `trackio list system-metrics --project <name> --run <name>` |
| List alerts | `trackio list alerts --project <name> [--run <name>] [--level <level>] [--since <timestamp>]` |
| Get project summary | `trackio get project --project <name>` |
| Get run summary | `trackio get run --project <name> --run <name>` |
| Get metric values | `trackio get metric --project <name> --run <name> --metric <name>` |
| Get metric at step | `trackio get metric ... --metric <name> --step <N>` |
| Get metric around step | `trackio get metric ... --metric <name> --around <N> --window <W>` |
| Get all metrics snapshot | `trackio get snapshot --project <name> --run <name> --step <N>` |
| Get system metrics | `trackio get system-metric --project <name> --run <name>` |
| Show dashboard | `trackio show [--project <name>]` |
| Sync to Space | `trackio sync --project <name> --space-id <space_id>` |

## Core Commands

### List Commands

```bash
trackio list projects                                       # List all projects
trackio list projects --json                                # JSON output

trackio list runs --project <name>                          # List runs in project
trackio list runs --project <name> --json                   # JSON output

trackio list metrics --project <name> --run <name>          # List metrics for run
trackio list metrics --project <name> --run <name> --json

trackio list system-metrics --project <name> --run <name>   # List system metrics
trackio list system-metrics --project <name> --run <name> --json

trackio list alerts --project <name>                        # List alerts
trackio list alerts --project <name> --run <name> --json    # Filter by run
trackio list alerts --project <name> --level error --json   # Filter by level
trackio list alerts --project <name> --json --since <ts>    # Poll since timestamp
```

### Get Commands

```bash
trackio get project --project <name>                               # Project summary
trackio get project --project <name> --json                        # JSON output

trackio get run --project <name> --run <name>                      # Run summary
trackio get run --project <name> --run <name> --json

trackio get metric --project <name> --run <name> --metric <name>   # Metric values
trackio get metric --project <name> --run <name> --metric <name> --json
trackio get metric ... --metric <name> --step 200                  # At exact step
trackio get metric ... --metric <name> --around 200 --window 10    # ±10 steps
trackio get metric ... --metric <name> --at-time <ts> --window 60  # ±60 seconds

trackio get snapshot --project <name> --run <name> --step 200 --json               # All metrics at step
trackio get snapshot --project <name> --run <name> --around 200 --window 5 --json  # Window
trackio get snapshot --project <name> --run <name> --at-time <ts> --window 60 --json

trackio get system-metric --project <name> --run <name>                  # All system metrics
trackio get system-metric --project <name> --run <name> --metric <name>  # Specific metric
trackio get system-metric --project <name> --run <name> --json
```

### Dashboard Commands

```bash
trackio show                                    # Launch dashboard
trackio show --project <name>                   # Load specific project
trackio show --theme <theme>                    # Custom theme
trackio show --mcp-server                       # Enable MCP server
trackio show --color-palette "#FF0000,#00FF00"  # Custom colors
```

### Sync Commands

```bash
trackio sync --project <name> --space-id <space_id>            # Sync to HF Space
trackio sync --project <name> --space-id <space_id> --private  # Private space
trackio sync --project <name> --space-id <space_id> --force    # Overwrite
```

## Output Formats

All `list` and `get` commands support two output formats:

- **Human-readable** (default): Formatted text for terminal viewing
- **JSON** (with `--json` flag): Structured JSON for programmatic use

## Common Patterns

### Discover Projects and Runs

```bash
# List all available projects
trackio list projects

# List runs in a project
trackio list runs --project my-project

# Get project overview
trackio get project --project my-project --json
```

### Inspect Run Details

```bash
# Get run summary with all metrics
trackio get run --project my-project --run my-run --json

# List available metrics
trackio list metrics --project my-project --run my-run

# Get specific metric values
trackio get metric --project my-project --run my-run --metric loss --json
```

### Query System Metrics

```bash
# List system metrics (GPU, etc.)
trackio list system-metrics --project my-project --run my-run

# Get all system metric data
trackio get system-metric --project my-project --run my-run --json

# Get specific system metric
trackio get system-metric --project my-project --run my-run --metric gpu_utilization --json
```

### Automation Scripts

```bash
# Extract latest metric value
LATEST_LOSS=$(trackio get metric --project my-project --run my-run --metric loss --json | jq -r '.values[-1].value')

# Export run summary to file
trackio get run --project my-project --run my-run --json > run_summary.json

# Filter runs with jq
trackio list runs --project my-project --json | jq '.runs[] | select(startswith("train"))'
```
### LLM Agent Workflow

```bash
# 1. Discover available projects
trackio list projects --json

# 2. Explore project structure
trackio get project --project my-project --json

# 3. Inspect specific run
trackio get run --project my-project --run my-run --json

# 4. Query metric values
trackio get metric --project my-project --run my-run --metric accuracy --json

# 5. Poll for alerts (use --since for efficient incremental polling)
trackio list alerts --project my-project --json --since "2025-06-01T00:00:00"

# 6. When an alert fires at step N, get all metrics around that point
trackio get snapshot --project my-project --run my-run --around 200 --window 5 --json
```
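The `--since` bookkeeping in step 5 amounts to advancing the cursor to the newest timestamp seen. A sketch of that logic; the `{"alerts": [{"timestamp": ...}]}` payload shape is an assumption for illustration, not documented output:

```python
# Incremental polling cursor: after each poll, advance `since` to the newest
# alert timestamp so the next `trackio list alerts --since ...` call only
# returns new alerts. ISO-8601 timestamps compare correctly as strings.
def next_since(payload: dict, since: str) -> str:
    timestamps = [a["timestamp"] for a in payload.get("alerts", [])]
    return max(timestamps + [since])

since = "2025-06-01T00:00:00"
payload = {"alerts": [
    {"timestamp": "2025-06-01T02:10:00", "level": "warn"},
    {"timestamp": "2025-06-01T03:45:00", "level": "error"},
]}
since = next_since(payload, since)
print(since)  # → 2025-06-01T03:45:00
```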
## Error Handling

Commands validate inputs and return clear errors:

- Missing project: `Error: Project '<name>' not found.`
- Missing run: `Error: Run '<name>' not found in project '<project>'.`
- Missing metric: `Error: Metric '<name>' not found in run '<run>' of project '<project>'.`

All errors exit with a non-zero status code and write to stderr.

## Key Options

- `--project`: Project name (required for most commands)
- `--run`: Run name (required for run-specific commands)
- `--metric`: Metric name (required for metric-specific commands)
- `--json`: Output in JSON format instead of human-readable
- `--step`: Exact step filter (for `get metric`, `get snapshot`)
- `--around`: Center step for window filter (for `get metric`, `get snapshot`)
- `--at-time`: Center ISO timestamp for window filter (for `get metric`, `get snapshot`)
- `--window`: Window size: ±steps for `--around`, ±seconds for `--at-time` (default: 10)
- `--level`: Alert level filter (`info`, `warn`, `error`) (for `list alerts`)
- `--since`: ISO timestamp to filter alerts after (for `list alerts`)
- `--theme`: Dashboard theme (for `show` command)
- `--mcp-server`: Enable MCP server mode (for `show` command)
- `--color-palette`: Comma-separated hex colors (for `show` command)
- `--private`: Create private Space (for `sync` command)
- `--force`: Overwrite existing database (for `sync` command)

## JSON Output Structure

### List Projects

```json
{"projects": ["project1", "project2"]}
```

### List Runs

```json
{"project": "my-project", "runs": ["run1", "run2"]}
```

### Project Summary

```json
{
  "project": "my-project",
  "num_runs": 3,
  "runs": ["run1", "run2", "run3"],
  "last_activity": 100
}
```

### Run Summary

```json
{
  "project": "my-project",
  "run": "my-run",
  "num_logs": 50,
  "metrics": ["loss", "accuracy"],
  "config": {"learning_rate": 0.001},
  "last_step": 49
}
```

### Metric Values

```json
{
  "project": "my-project",
  "run": "my-run",
  "metric": "loss",
  "values": [
    {"step": 0, "timestamp": "2024-01-01T00:00:00", "value": 0.5},
    {"step": 1, "timestamp": "2024-01-01T00:01:00", "value": 0.4}
  ]
}
```
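Consuming the metric-values payload from Python is straightforward. A sketch using the exact sample shown above as inline data (in practice it would come from the CLI's stdout):

```python
import json

# Parse the "Metric Values" payload documented above and find the best step.
payload = json.loads("""
{
  "project": "my-project",
  "run": "my-run",
  "metric": "loss",
  "values": [
    {"step": 0, "timestamp": "2024-01-01T00:00:00", "value": 0.5},
    {"step": 1, "timestamp": "2024-01-01T00:01:00", "value": 0.4}
  ]
}
""")

best = min(payload["values"], key=lambda v: v["value"])
print(best["step"], best["value"])  # → 1 0.4
```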
## References

- **Complete CLI documentation**: See [docs/source/cli_commands.md](docs/source/cli_commands.md)
- **API and MCP Server**: See [docs/source/api_mcp_server.md](docs/source/api_mcp_server.md)

@@ -0,0 +1,595 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-vision-trainer"
name: hugging-face-vision-trainer
description: Train or fine-tune vision models on Hugging Face Jobs for detection, classification, and SAM or SAM2 segmentation.
risk: unknown
---

# Vision Model Training on Hugging Face Jobs

Train object detection, image classification, and SAM/SAM2 segmentation models on managed cloud GPUs. No local GPU setup required—results are automatically saved to the Hugging Face Hub.

## When to Use This Skill

Use this skill when users want to:

- Fine-tune object detection models (D-FINE, RT-DETR v2, DETR, YOLOS) on cloud GPUs or locally
- Fine-tune image classification models (timm: MobileNetV3, MobileViT, ResNet, ViT/DINOv3, or any Transformers classifier) on cloud GPUs or locally
- Fine-tune SAM or SAM2 models for segmentation / image matting using bbox or point prompts
- Train bounding-box detectors on custom datasets
- Train image classifiers on custom datasets
- Train segmentation models on custom mask datasets with prompts
- Run vision training jobs on Hugging Face Jobs infrastructure
- Ensure trained vision models are permanently saved to the Hub

## Related Skills

- **`hugging-face-jobs`** — General HF Jobs infrastructure: token authentication, hardware flavors, timeout management, cost estimation, secrets, environment variables, scheduled jobs, and result persistence. **Refer to the Jobs skill for any non-training-specific Jobs questions** (e.g., "how do secrets work?", "what hardware is available?", "how do I pass tokens?").
- **`hugging-face-model-trainer`** — TRL-based language model training (SFT, DPO, GRPO). Use that skill for text/language model fine-tuning.

## Local Script Execution

Helper scripts use PEP 723 inline dependencies. Run them with `uv run`:

```bash
uv run scripts/dataset_inspector.py --dataset username/dataset-name --split train
uv run scripts/estimate_cost.py --help
```

## Prerequisites Checklist

Before starting any training job, verify:

### Account & Authentication

- Hugging Face account with a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan (Jobs require a paid plan)
- Authenticated login: check with `hf_whoami()` (tool) or `hf auth whoami` (terminal)
- Token has **write** permissions
- **MUST pass token in job secrets** — see directive #1 below for syntax (MCP tool vs Python API)

### Dataset Requirements — Object Detection

- Dataset must exist on the Hub
- Annotations must use the `objects` column with `bbox`, `category` (and optionally `area`) sub-fields
- Bboxes can be in **xywh (COCO)** or **xyxy (Pascal VOC)** format — auto-detected and converted
- Categories can be **integers or strings** — strings are auto-remapped to integer IDs
- `image_id` column is **optional** — generated automatically if missing
- **ALWAYS validate unknown datasets** before GPU training (see Dataset Validation section)

### Dataset Requirements — Image Classification

- Dataset must exist on the Hub
- Must have an **`image` column** (PIL images) and a **`label` column** (integer class IDs or strings)
- The label column can be `ClassLabel` type (with names) or plain integers/strings — strings are auto-remapped
- Common column names auto-detected: `label`, `labels`, `class`, `fine_label`
- **ALWAYS validate unknown datasets** before GPU training (see Dataset Validation section)

### Dataset Requirements — SAM/SAM2 Segmentation

- Dataset must exist on the Hub
- Must have an **`image` column** (PIL images) and a **`mask` column** (binary ground-truth segmentation mask)
- Must have a **prompt** — either:
  - A **`prompt` column** with JSON containing `{"bbox": [x0,y0,x1,y1]}` or `{"point": [x,y]}`
  - OR a dedicated **`bbox`** column with `[x0,y0,x1,y1]` values
  - OR a dedicated **`point`** column with `[x,y]` or `[[x,y],...]` values
- Bboxes should be in **xyxy** format (absolute pixel coordinates)
- Example dataset: `merve/MicroMat-mini` (image matting with bbox prompts)
- **ALWAYS validate unknown datasets** before GPU training (see Dataset Validation section)
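If a dataset has masks but no prompt column, a bbox prompt can be derived from each mask. A plain-Python sketch of the idea (the `mask_to_xyxy` helper is hypothetical; a real pipeline would use numpy on the PIL mask):

```python
# Hypothetical helper: derive an xyxy bbox prompt (absolute pixel coordinates)
# from a binary mask given as nested lists, matching the prompt format above.
def mask_to_xyxy(mask):
    xs = [x for row in mask for x, v in enumerate(row) if v]
    ys = [y for y, row in enumerate(mask) if any(row)]
    if not xs:
        return None  # empty mask: no prompt can be derived
    return [min(xs), min(ys), max(xs), max(ys)]

mask = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 1],
    [0, 0, 0, 0],
]
print(mask_to_xyxy(mask))  # → [1, 1, 3, 2]
```

Applied per-row with `datasets.Dataset.map`, this would populate the dedicated `bbox` column described above.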
### Critical Settings

- **Timeout must exceed expected training time** — the default 30 min is TOO SHORT. See directive #5 for recommended values.
- **Hub push must be enabled** — `push_to_hub=True`, `hub_model_id="username/model-name"`, token in `secrets`

## Dataset Validation

**Validate dataset format BEFORE launching GPU training to prevent the #1 cause of training failures: format mismatches.**

**ALWAYS validate** unknown/custom datasets or any dataset you haven't trained with before. **Skip for** `cppe-5` (the default in the training script).

### Running the Inspector

**Option 1: Via HF Jobs (recommended — avoids local SSL/dependency issues):**

```python
hf_jobs("uv", {
    "script": "path/to/dataset_inspector.py",
    "script_args": ["--dataset", "username/dataset-name", "--split", "train"]
})
```

**Option 2: Locally:**

```bash
uv run scripts/dataset_inspector.py --dataset username/dataset-name --split train
```

**Option 3: Via `HfApi().run_uv_job()` (if the `hf_jobs` MCP tool is unavailable):**

```python
from huggingface_hub import HfApi

api = HfApi()
api.run_uv_job(
    script="scripts/dataset_inspector.py",
    script_args=["--dataset", "username/dataset-name", "--split", "train"],
    flavor="cpu-basic",
    timeout=300,
)
```

### Reading Results

- **`✓ READY`** — Dataset is compatible, use directly
- **`✗ NEEDS FORMATTING`** — Needs preprocessing (mapping code provided in output)

## Automatic Bbox Preprocessing

The object detection training script (`scripts/object_detection_training.py`) automatically handles bbox format detection (xyxy→xywh conversion), bbox sanitization, `image_id` generation, string category→integer remapping, and dataset truncation. **No manual preprocessing needed** — just ensure the dataset has `objects.bbox` and `objects.category` columns.
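For reference, the xyxy→xywh conversion the script performs amounts to the following. This is a sketch of the arithmetic, not the script's actual implementation:

```python
# xyxy (Pascal VOC: corner coordinates) → xywh (COCO: corner + width/height).
def xyxy_to_xywh(box):
    x0, y0, x1, y1 = box
    return [x0, y0, x1 - x0, y1 - y0]

print(xyxy_to_xywh([10, 20, 110, 70]))  # → [10, 20, 100, 50]
```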
## Training workflow

Copy this checklist and track progress:

```
Training Progress:
- [ ] Step 1: Verify prerequisites (account, token, dataset)
- [ ] Step 2: Validate dataset format (run dataset_inspector.py)
- [ ] Step 3: Ask user about dataset size and validation split
- [ ] Step 4: Prepare training script (OD: scripts/object_detection_training.py, IC: scripts/image_classification_training.py, SAM: scripts/sam_segmentation_training.py)
- [ ] Step 5: Save script locally, submit job, and report details
```

**Step 1: Verify prerequisites**

Follow the Prerequisites Checklist above.

**Step 2: Validate dataset**

Run the dataset inspector BEFORE spending GPU time. See the "Dataset Validation" section above.

**Step 3: Ask user preferences**

ALWAYS use the AskUserQuestion tool with option-style format:

```python
AskUserQuestion({
    "questions": [
        {
            "question": "Do you want to run a quick test with a subset of the data first?",
            "header": "Dataset Size",
            "options": [
                {"label": "Quick test run (10% of data)", "description": "Faster, cheaper (~30-60 min, ~$2-5) to validate setup"},
                {"label": "Full dataset (Recommended)", "description": "Complete training for best model quality"}
            ],
            "multiSelect": False
        },
        {
            "question": "Do you want to create a validation split from the training data?",
            "header": "Split data",
            "options": [
                {"label": "Yes (Recommended)", "description": "Automatically split 15% of training data for validation"},
                {"label": "No", "description": "Use existing validation split from dataset"}
            ],
            "multiSelect": False
        },
        {
            "question": "Which GPU hardware do you want to use?",
            "header": "Hardware Flavor",
            "options": [
                {"label": "t4-small ($0.40/hr)", "description": "1x T4, 16 GB VRAM — sufficient for all OD models under 100M params"},
                {"label": "l4x1 ($0.80/hr)", "description": "1x L4, 24 GB VRAM — more headroom for large images or batch sizes"},
                {"label": "a10g-large ($1.50/hr)", "description": "1x A10G, 24 GB VRAM — faster training, more CPU/RAM"},
                {"label": "a100-large ($2.50/hr)", "description": "1x A100, 80 GB VRAM — fastest, for very large datasets or image sizes"}
            ],
            "multiSelect": False
        }
    ]
})
```

**Step 4: Prepare training script**

For object detection, use [scripts/object_detection_training.py](scripts/object_detection_training.py) as the production-ready template. For image classification, use [scripts/image_classification_training.py](scripts/image_classification_training.py). For SAM/SAM2 segmentation, use [scripts/sam_segmentation_training.py](scripts/sam_segmentation_training.py). All scripts use `HfArgumentParser` — all configuration is passed via CLI arguments in `script_args`, NOT by editing Python variables. For timm model details, see [references/timm_trainer.md](references/timm_trainer.md). For SAM2 training details, see [references/finetune_sam2_trainer.md](references/finetune_sam2_trainer.md).

**Step 5: Save script, submit job, and report**

1. **Save the script locally** to `submitted_jobs/` in the workspace root (create it if needed) with a descriptive name like `training_<dataset>_<YYYYMMDD_HHMMSS>.py`. Tell the user the path.
2. **Submit** using the `hf_jobs` MCP tool (preferred) or `HfApi().run_uv_job()` — see directive #1 for both methods. Pass all config via `script_args`.
3. **Report** the job ID (from the `.id` attribute), monitoring URL, Trackio dashboard (`https://huggingface.co/spaces/{username}/trackio`), expected time, and estimated cost.
4. **Wait for the user** to request status checks — don't poll automatically. Training jobs run asynchronously and can take hours.

## Critical directives

These rules prevent common failures. Follow them exactly.

### 1. Job submission: `hf_jobs` MCP tool vs Python API

**`hf_jobs()` is an MCP tool, NOT a Python function.** Do NOT try to import it from `huggingface_hub`. Call it as a tool:

```
hf_jobs("uv", {"script": training_script_content, "flavor": "a10g-large", "timeout": "4h", "secrets": {"HF_TOKEN": "$HF_TOKEN"}})
```

**If the `hf_jobs` MCP tool is unavailable**, use the Python API directly:

```python
from huggingface_hub import HfApi, get_token

api = HfApi()
job_info = api.run_uv_job(
    script="path/to/training_script.py",  # file PATH, NOT content
    script_args=["--dataset_name", "cppe-5", ...],
    flavor="a10g-large",
    timeout=14400,                        # seconds (4 hours)
    env={"PYTHONUNBUFFERED": "1"},
    secrets={"HF_TOKEN": get_token()},    # MUST use get_token(), NOT "$HF_TOKEN"
)
print(f"Job ID: {job_info.id}")
```

**Critical differences between the two methods:**

| | `hf_jobs` MCP tool | `HfApi().run_uv_job()` |
|---|---|---|
| `script` param | Python code string or URL (NOT local paths) | File path to `.py` file (NOT content) |
| Token in secrets | `"$HF_TOKEN"` (auto-replaced) | `get_token()` (actual token value) |
| Timeout format | String (`"4h"`) | Seconds (`14400`) |

**Rules for both methods:**

- The training script MUST include PEP 723 inline metadata with dependencies
- Do NOT use `image` or `command` parameters (those belong to `run_job()`, not `run_uv_job()`)
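A minimal PEP 723 header looks like this at the top of the training script (the dependency list here is illustrative; declare whatever the script actually imports):

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "transformers",
#     "datasets",
# ]
# ///
```

`uv run` reads this block and installs the listed dependencies into an ephemeral environment before executing the script.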
### 2. Authentication via job secrets + explicit hub_token injection

**Job config** MUST include the token in secrets — the syntax depends on the submission method (see the table above).

**Training script requirement:** The Transformers `Trainer` calls `create_repo(token=self.args.hub_token)` during `__init__()` when `push_to_hub=True`. The training script MUST inject `HF_TOKEN` into `training_args.hub_token` AFTER parsing args but BEFORE creating the `Trainer`. The template `scripts/object_detection_training.py` already includes this:

```python
hf_token = os.environ.get("HF_TOKEN")
if training_args.push_to_hub and not training_args.hub_token:
    if hf_token:
        training_args.hub_token = hf_token
```

If you write a custom script, you MUST include this token injection before the `Trainer(...)` call.

- Do NOT call `login()` in custom scripts unless replicating the full pattern from `scripts/object_detection_training.py`
- Do NOT rely on implicit token resolution (`hub_token=None`) — it is unreliable in Jobs
- See the `hugging-face-jobs` skill → *Token Usage Guide* for full details

### 3. JobInfo attribute

Access the job identifier using `.id` (NOT `.job_id` or `.name` — these don't exist):

```python
job_info = api.run_uv_job(...)  # or hf_jobs("uv", {...})
job_id = job_info.id  # Correct -- returns a string like "687fb701029421ae5549d998"
```

### 4. Required training flags and HfArgumentParser boolean syntax

`scripts/object_detection_training.py` uses `HfArgumentParser` — all config is passed via `script_args`. Boolean arguments have two syntaxes:

- **`bool` fields** (e.g., `push_to_hub`, `do_train`): Use as bare flags (`--push_to_hub`) or negate with the `--no_` prefix (`--no_remove_unused_columns`)
- **`Optional[bool]` fields** (e.g., `greater_is_better`): MUST pass an explicit value (`--greater_is_better True`). Bare `--greater_is_better` causes `error: expected one argument`
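The distinction comes from the field's type annotation. An illustrative dataclass (not the real `TrainingArguments`) showing which annotation produces which CLI syntax, per the behavior described above:

```python
from dataclasses import dataclass
from typing import Optional

# Illustrative sketch: with HfArgumentParser, a plain `bool` field becomes a
# store-true flag (plus a `--no_<name>` negation), while an `Optional[bool]`
# field becomes an option that requires an explicit True/False value.
@dataclass
class Args:
    push_to_hub: bool = False                  # CLI: bare flag --push_to_hub
    remove_unused_columns: bool = True         # CLI: negate with --no_remove_unused_columns
    greater_is_better: Optional[bool] = None   # CLI: --greater_is_better True

args = Args(push_to_hub=True, remove_unused_columns=False, greater_is_better=True)
print(args)
```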
Required flags for object detection:

```
--no_remove_unused_columns      # MUST: preserves image column for pixel_values
--no_eval_do_concat_batches     # MUST: images have different numbers of target boxes
--push_to_hub                   # MUST: environment is ephemeral
--hub_model_id username/model-name
--metric_for_best_model eval_map
--greater_is_better True        # MUST pass "True" explicitly (Optional[bool])
--do_train
--do_eval
```

Required flags for image classification:

```
--no_remove_unused_columns      # MUST: preserves image column for pixel_values
--push_to_hub                   # MUST: environment is ephemeral
--hub_model_id username/model-name
--metric_for_best_model eval_accuracy
--greater_is_better True        # MUST pass "True" explicitly (Optional[bool])
--do_train
--do_eval
```

Required flags for SAM/SAM2 segmentation:

```
--remove_unused_columns False   # MUST: preserves input_boxes/input_points
--push_to_hub                   # MUST: environment is ephemeral
--hub_model_id username/model-name
--do_train
--prompt_type bbox              # or "point"
--dataloader_pin_memory False   # MUST: avoids pin_memory issues with custom collator
```

### 5. Timeout management

The default 30 min is TOO SHORT for object detection. Set a minimum of 2-4 hours and add a 30% buffer for model loading, preprocessing, and the Hub push.

| Scenario | Timeout |
|----------|---------|
| Quick test (100-200 images, 5-10 epochs) | 1h |
| Development (500-1K images, 15-20 epochs) | 2-3h |
| Production (1K-5K images, 30 epochs) | 4-6h |
| Large dataset (5K+ images) | 6-12h |
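The 30% buffer rule can be captured in a tiny helper (hypothetical; `run_uv_job` expects the timeout in whole seconds):

```python
# Hypothetical helper for the rule above: take an expected training time,
# add the 30% buffer, and return whole seconds for HfApi().run_uv_job(timeout=...).
def timeout_seconds(expected_hours: float, buffer: float = 0.30) -> int:
    return int(expected_hours * (1 + buffer) * 3600)

print(timeout_seconds(4))  # 4h expected training → 18720 s (~5.2h)
```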
|
||||
|
||||
### 6. Trackio monitoring

Trackio is **always enabled** in the object detection training script — it calls `trackio.init()` and `trackio.finish()` automatically. No need to pass `--report_to trackio`. The project name is taken from `--output_dir` and the run name from `--run_name`. For image classification, pass `--report_to trackio` in `TrainingArguments`.

Dashboard at: `https://huggingface.co/spaces/{username}/trackio`

## Model & hardware selection

### Recommended object detection models

| Model | Params | Use case |
|-------|--------|----------|
| `ustc-community/dfine-small-coco` | 10.4M | Best starting point — fast, cheap, SOTA quality |
| `PekingU/rtdetr_v2_r18vd` | 20.2M | Lightweight real-time detector |
| `ustc-community/dfine-large-coco` | 31.4M | Higher accuracy, still efficient |
| `PekingU/rtdetr_v2_r50vd` | 43M | Strong real-time baseline |
| `ustc-community/dfine-xlarge-obj365` | 63.5M | Best accuracy (pretrained on Objects365) |
| `PekingU/rtdetr_v2_r101vd` | 76M | Largest RT-DETR v2 variant |

Start with `ustc-community/dfine-small-coco` for fast iteration. Move to D-FINE Large or RT-DETR v2 R50 for better accuracy.

### Recommended image classification models

All `timm/` models work out of the box via `AutoModelForImageClassification` (loaded as `TimmWrapperForImageClassification`). See [references/timm_trainer.md](references/timm_trainer.md) for details.

| Model | Params | Use case |
|-------|--------|----------|
| `timm/mobilenetv3_small_100.lamb_in1k` | 2.5M | Ultra-lightweight — mobile/edge, fastest training |
| `timm/mobilevit_s.cvnets_in1k` | 5.6M | Mobile transformer — good accuracy/speed trade-off |
| `timm/resnet50.a1_in1k` | 25.6M | Strong CNN baseline — reliable, well-studied |
| `timm/vit_base_patch16_dinov3.lvd1689m` | 86.6M | Best accuracy — DINOv3 self-supervised ViT |

Start with `timm/mobilenetv3_small_100.lamb_in1k` for fast iteration. Move to `timm/resnet50.a1_in1k` or `timm/vit_base_patch16_dinov3.lvd1689m` for better accuracy.

### Recommended SAM/SAM2 segmentation models

| Model | Params | Use case |
|-------|--------|----------|
| `facebook/sam2.1-hiera-tiny` | 38.9M | Fastest SAM2 — good for quick experiments |
| `facebook/sam2.1-hiera-small` | 46.0M | Best starting point — good quality/speed balance |
| `facebook/sam2.1-hiera-base-plus` | 80.8M | Higher capacity for complex segmentation |
| `facebook/sam2.1-hiera-large` | 224.4M | Best SAM2 accuracy — requires more VRAM |
| `facebook/sam-vit-base` | 93.7M | Original SAM — ViT-B backbone |
| `facebook/sam-vit-large` | 312.3M | Original SAM — ViT-L backbone |
| `facebook/sam-vit-huge` | 641.1M | Original SAM — ViT-H, best SAM v1 accuracy |

Start with `facebook/sam2.1-hiera-small` for fast iteration. SAM2 models are generally more efficient than SAM v1 at similar quality. Only the mask decoder is trained by default (vision and prompt encoders are frozen).

### Hardware recommendation

All recommended OD and IC models are under 100M params — **`t4-small` (16 GB VRAM, $0.40/hr) is sufficient for all of them.** Image classification models are generally smaller and faster than object detection models — `t4-small` handles even ViT-Base comfortably. For SAM2 models up to `hiera-base-plus`, `t4-small` is sufficient since only the mask decoder is trained. For `sam2.1-hiera-large` or SAM v1 models, use `l4x1` or `a10g-large`. Only upgrade if you hit OOM from large batch sizes — reduce batch size first before switching hardware. Common upgrade path: `t4-small` → `l4x1` ($0.80/hr, 24 GB) → `a10g-large` ($1.50/hr, 24 GB).

For full hardware flavor list: refer to the `hugging-face-jobs` skill. For cost estimation: run `scripts/estimate_cost.py`.

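For a back-of-envelope check before running `scripts/estimate_cost.py`, the hourly rates quoted in the hardware paragraph above multiply out directly (a rough sketch, not the script's actual logic):

```python
# Rough cost sketch using the hourly rates quoted above
# (t4-small $0.40/hr, l4x1 $0.80/hr, a10g-large $1.50/hr).
RATES = {"t4-small": 0.40, "l4x1": 0.80, "a10g-large": 1.50}

def estimate_cost(flavor: str, hours: float) -> float:
    return round(RATES[flavor] * hours, 2)

print(estimate_cost("t4-small", 4))  # a 4h production OD run costs ~$1.60
```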
## Quick start — Object Detection

The `script_args` below are the same for both submission methods. See directive #1 for the critical differences between them.

```python
OD_SCRIPT_ARGS = [
    "--model_name_or_path", "ustc-community/dfine-small-coco",
    "--dataset_name", "cppe-5",
    "--image_square_size", "640",
    "--output_dir", "dfine_finetuned",
    "--num_train_epochs", "30",
    "--per_device_train_batch_size", "8",
    "--learning_rate", "5e-5",
    "--eval_strategy", "epoch",
    "--save_strategy", "epoch",
    "--save_total_limit", "2",
    "--load_best_model_at_end",
    "--metric_for_best_model", "eval_map",
    "--greater_is_better", "True",
    "--no_remove_unused_columns",
    "--no_eval_do_concat_batches",
    "--push_to_hub",
    "--hub_model_id", "username/model-name",
    "--do_train",
    "--do_eval",
]
```

```python
from huggingface_hub import HfApi, get_token

api = HfApi()
job_info = api.run_uv_job(
    script="scripts/object_detection_training.py",
    script_args=OD_SCRIPT_ARGS,
    flavor="t4-small",
    timeout=14400,
    env={"PYTHONUNBUFFERED": "1"},
    secrets={"HF_TOKEN": get_token()},
)
print(f"Job ID: {job_info.id}")
```

### Key OD `script_args`

- `--model_name_or_path` — recommended: `"ustc-community/dfine-small-coco"` (see model table above)
- `--dataset_name` — the Hub dataset ID
- `--image_square_size` — 480 (fast iteration) or 800 (better accuracy)
- `--hub_model_id` — `"username/model-name"` for Hub persistence
- `--num_train_epochs` — 30 typical for convergence
- `--train_val_split` — fraction to split for validation (default 0.15), set if dataset lacks a validation split
- `--max_train_samples` — truncate training set (useful for quick test runs, e.g. `"785"` for ~10% of a 7.8K dataset)
- `--max_eval_samples` — truncate evaluation set

## Quick start — Image Classification

```python
IC_SCRIPT_ARGS = [
    "--model_name_or_path", "timm/mobilenetv3_small_100.lamb_in1k",
    "--dataset_name", "ethz/food101",
    "--output_dir", "food101_classifier",
    "--num_train_epochs", "5",
    "--per_device_train_batch_size", "32",
    "--per_device_eval_batch_size", "32",
    "--learning_rate", "5e-5",
    "--eval_strategy", "epoch",
    "--save_strategy", "epoch",
    "--save_total_limit", "2",
    "--load_best_model_at_end",
    "--metric_for_best_model", "eval_accuracy",
    "--greater_is_better", "True",
    "--no_remove_unused_columns",
    "--push_to_hub",
    "--hub_model_id", "username/food101-classifier",
    "--do_train",
    "--do_eval",
]
```

```python
from huggingface_hub import HfApi, get_token

api = HfApi()
job_info = api.run_uv_job(
    script="scripts/image_classification_training.py",
    script_args=IC_SCRIPT_ARGS,
    flavor="t4-small",
    timeout=7200,
    env={"PYTHONUNBUFFERED": "1"},
    secrets={"HF_TOKEN": get_token()},
)
print(f"Job ID: {job_info.id}")
```

### Key IC `script_args`

- `--model_name_or_path` — any `timm/` model or Transformers classification model (see model table above)
- `--dataset_name` — the Hub dataset ID
- `--image_column_name` — column containing PIL images (default: `"image"`)
- `--label_column_name` — column containing class labels (default: `"label"`)
- `--hub_model_id` — `"username/model-name"` for Hub persistence
- `--num_train_epochs` — 3-5 typical for classification (fewer than OD)
- `--per_device_train_batch_size` — 16-64 (classification models use less memory than OD)
- `--train_val_split` — fraction to split for validation (default 0.15), set if dataset lacks a validation split
- `--max_train_samples` / `--max_eval_samples` — truncate for quick tests

## Quick start — SAM/SAM2 Segmentation

```python
SAM_SCRIPT_ARGS = [
    "--model_name_or_path", "facebook/sam2.1-hiera-small",
    "--dataset_name", "merve/MicroMat-mini",
    "--prompt_type", "bbox",
    "--prompt_column_name", "prompt",
    "--output_dir", "sam2-finetuned",
    "--num_train_epochs", "30",
    "--per_device_train_batch_size", "4",
    "--learning_rate", "1e-5",
    "--logging_steps", "1",
    "--save_strategy", "epoch",
    "--save_total_limit", "2",
    "--remove_unused_columns", "False",
    "--dataloader_pin_memory", "False",
    "--push_to_hub",
    "--hub_model_id", "username/sam2-finetuned",
    "--do_train",
    "--report_to", "trackio",
]
```

```python
from huggingface_hub import HfApi, get_token

api = HfApi()
job_info = api.run_uv_job(
    script="scripts/sam_segmentation_training.py",
    script_args=SAM_SCRIPT_ARGS,
    flavor="t4-small",
    timeout=7200,
    env={"PYTHONUNBUFFERED": "1"},
    secrets={"HF_TOKEN": get_token()},
)
print(f"Job ID: {job_info.id}")
```

### Key SAM `script_args`

- `--model_name_or_path` — SAM or SAM2 model (see model table above); auto-detects SAM vs SAM2
- `--dataset_name` — the Hub dataset ID (e.g., `"merve/MicroMat-mini"`)
- `--prompt_type` — `"bbox"` or `"point"` — type of prompt in the dataset
- `--prompt_column_name` — column with JSON-encoded prompts (default: `"prompt"`)
- `--bbox_column_name` — dedicated bbox column (alternative to JSON prompt column)
- `--point_column_name` — dedicated point column (alternative to JSON prompt column)
- `--mask_column_name` — column with ground-truth masks (default: `"mask"`)
- `--hub_model_id` — `"username/model-name"` for Hub persistence
- `--num_train_epochs` — 20-30 typical for SAM fine-tuning
- `--per_device_train_batch_size` — 2-4 (SAM models use significant memory)
- `--freeze_vision_encoder` / `--freeze_prompt_encoder` — freeze encoder weights (default: both frozen, only mask decoder trains)
- `--train_val_split` — fraction to split for validation (default 0.1)

## Checking job status

**MCP tool (if available):**

```
hf_jobs("ps")                                  # List all jobs
hf_jobs("logs", {"job_id": "your-job-id"})     # View logs
hf_jobs("inspect", {"job_id": "your-job-id"})  # Job details
```

**Python API fallback:**

```python
from huggingface_hub import HfApi

api = HfApi()
api.list_jobs()                         # List all jobs
api.get_job_logs(job_id="your-job-id")  # View logs
api.get_job(job_id="your-job-id")       # Job details
```

## Common failure modes

### OOM (CUDA out of memory)

Reduce `per_device_train_batch_size` (try 4, then 2), reduce the image size (`--image_square_size`), or upgrade hardware.

### Dataset format errors

Run `scripts/dataset_inspector.py` first. The training script auto-detects xyxy vs xywh, converts string categories to integer IDs, and adds `image_id` if missing. Ensure `objects.bbox` contains 4-value coordinate lists in absolute pixels and `objects.category` contains either integer IDs or string labels.

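For intuition, the xywh → xyxy conversion amounts to the following (a minimal sketch; the actual auto-detection logic lives in `scripts/object_detection_training.py`):

```python
# Sketch: convert a COCO-style [x, y, width, height] box (absolute pixels)
# to the [x_min, y_min, x_max, y_max] form.
def xywh_to_xyxy(box):
    x, y, w, h = box
    return [x, y, x + w, y + h]

print(xywh_to_xyxy([10, 20, 30, 40]))  # [10, 20, 40, 60]
```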
### Hub push failures (401)

Verify: (1) job secrets include token (see directive #2), (2) script sets `training_args.hub_token` BEFORE creating the `Trainer`, (3) `push_to_hub=True` is set, (4) correct `hub_model_id`, (5) token has write permissions.

### Job timeout

Increase timeout (see directive #5 table), reduce epochs/dataset, or use checkpoint strategy with `hub_strategy="every_save"`.

### KeyError: 'test' (missing test split)

The object detection training script handles this gracefully — it falls back to the `validation` split. Ensure you're using the latest `scripts/object_detection_training.py`.

### Single-class dataset: "iteration over a 0-d tensor"

`torchmetrics.MeanAveragePrecision` returns scalar (0-d) tensors for per-class metrics when there's only one class. The template `scripts/object_detection_training.py` handles this by calling `.unsqueeze(0)` on these tensors. Ensure you're using the latest template.

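The fix can be shown in isolation (a sketch of the same `.unsqueeze(0)` promotion the template applies):

```python
import torch

# With a single class, per-class mAP comes back as a 0-d tensor,
# which cannot be iterated; promote it to 1-d first.
map_per_class = torch.tensor(0.42)  # 0-d scalar tensor
if map_per_class.ndim == 0:
    map_per_class = map_per_class.unsqueeze(0)

for class_map in map_per_class:  # now iterable
    print(round(class_map.item(), 2))
```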
### Poor detection performance (mAP < 0.15)

Increase epochs (30-50), ensure 500+ images, check per-class mAP for imbalanced classes, try different learning rates (1e-5 to 1e-4), increase image size.

For comprehensive troubleshooting: see [references/reliability_principles.md](references/reliability_principles.md)

## Reference files

- [scripts/object_detection_training.py](scripts/object_detection_training.py) — Production-ready object detection training script
- [scripts/image_classification_training.py](scripts/image_classification_training.py) — Production-ready image classification training script (supports timm models)
- [scripts/sam_segmentation_training.py](scripts/sam_segmentation_training.py) — Production-ready SAM/SAM2 segmentation training script (bbox & point prompts)
- [scripts/dataset_inspector.py](scripts/dataset_inspector.py) — Validate dataset format for OD, classification, and SAM segmentation
- [scripts/estimate_cost.py](scripts/estimate_cost.py) — Estimate training costs for any vision model (includes SAM/SAM2)
- [references/object_detection_training_notebook.md](references/object_detection_training_notebook.md) — Object detection training workflow, augmentation strategies, and training patterns
- [references/image_classification_training_notebook.md](references/image_classification_training_notebook.md) — Image classification training workflow with ViT, preprocessing, and evaluation
- [references/finetune_sam2_trainer.md](references/finetune_sam2_trainer.md) — SAM2 fine-tuning walkthrough with MicroMat dataset, DiceCE loss, and Trainer integration
- [references/timm_trainer.md](references/timm_trainer.md) — Using timm models with HF Trainer (TimmWrapper, transforms, full example)
- [references/hub_saving.md](references/hub_saving.md) — Detailed Hub persistence guide and verification checklist
- [references/reliability_principles.md](references/reliability_principles.md) — Failure prevention principles from production experience

## External links

- [Transformers Object Detection Guide](https://huggingface.co/docs/transformers/tasks/object_detection)
- [Transformers Image Classification Guide](https://huggingface.co/docs/transformers/tasks/image_classification)
- [DETR Model Documentation](https://huggingface.co/docs/transformers/model_doc/detr)
- [ViT Model Documentation](https://huggingface.co/docs/transformers/model_doc/vit)
- [HF Jobs Guide](https://huggingface.co/docs/huggingface_hub/guides/jobs) — Main Jobs documentation
- [HF Jobs Configuration](https://huggingface.co/docs/hub/en/jobs-configuration) — Hardware, secrets, timeouts, namespaces
- [HF Jobs CLI Reference](https://huggingface.co/docs/huggingface_hub/guides/cli#hf-jobs) — Command line interface
- [Object Detection Models](https://huggingface.co/models?pipeline_tag=object-detection)
- [Image Classification Models](https://huggingface.co/models?pipeline_tag=image-classification)
- [SAM2 Model Documentation](https://huggingface.co/docs/transformers/model_doc/sam2)
- [SAM Model Documentation](https://huggingface.co/docs/transformers/model_doc/sam)
- [Object Detection Datasets](https://huggingface.co/datasets?task_categories=task_categories:object-detection)
- [Image Classification Datasets](https://huggingface.co/datasets?task_categories=task_categories:image-classification)

# Fine-tuning SAM2 with HF Trainer

Fine-tune SAM2.1 on a small part of the MicroMat dataset for image matting, using the Hugging Face Trainer with a custom loss function.

```python
!pip install -q transformers datasets monai trackio
```

## Load and explore the dataset

```python
from datasets import load_dataset

dataset = load_dataset("merve/MicroMat-mini", split="train")
dataset
```

```python
dataset = dataset.train_test_split(test_size=0.1)
train_ds = dataset["train"]
val_ds = dataset["test"]
```

```python
import json

train_ds[0]
```

```python
json.loads(train_ds["prompt"][0])["bbox"]
```

## Visualize a sample

```python
import matplotlib.pyplot as plt
import numpy as np


def show_mask(mask, ax, bbox):
    color = np.array([0.12, 0.56, 1.0, 0.6])
    mask = np.array(mask)
    h, w = mask.shape
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, 4)
    ax.imshow(mask_image)
    x0, y0, x1, y1 = bbox
    ax.add_patch(
        plt.Rectangle(
            (x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor="lime", linewidth=2
        )
    )


example = train_ds[0]
image = np.array(example["image"])
ground_truth_mask = np.array(example["mask"])

fig, ax = plt.subplots()
ax.imshow(image)
show_mask(ground_truth_mask, ax, json.loads(example["prompt"])["bbox"])
ax.set_title("Ground truth mask")
ax.set_axis_off()
plt.show()
```

## Build the dataset and collator

`SAMDataset` wraps each sample into the format expected by the SAM2 processor. Ground-truth masks are stored under the key `"labels"` so the Trainer automatically pops them before calling `model.forward()`.

```python
from torch.utils.data import Dataset
import torch
import torch.nn.functional as F


class SAMDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        prompt = json.loads(item["prompt"])["bbox"]
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        inputs["labels"] = (np.array(item["mask"]) > 0).astype(np.float32)
        inputs["original_image_size"] = torch.tensor(image.size[::-1])
        return inputs


def collate_fn(batch):
    pixel_values = torch.cat([item["pixel_values"] for item in batch], dim=0)
    original_sizes = torch.stack([item["original_sizes"] for item in batch])
    input_boxes = torch.cat([item["input_boxes"] for item in batch], dim=0)
    labels = torch.cat(
        [
            F.interpolate(
                torch.as_tensor(x["labels"]).unsqueeze(0).unsqueeze(0).float(),
                size=(256, 256),
                mode="nearest",
            )
            for x in batch
        ],
        dim=0,
    ).long()

    return {
        "pixel_values": pixel_values,
        "original_sizes": original_sizes,
        "input_boxes": input_boxes,
        "labels": labels,
        "original_image_size": torch.stack(
            [item["original_image_size"] for item in batch]
        ),
        "multimask_output": False,
    }
```

```python
from transformers import Sam2Processor

processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-small")

train_dataset = SAMDataset(dataset=train_ds, processor=processor)
val_dataset = SAMDataset(dataset=val_ds, processor=processor)
```

## Load model and freeze encoder layers

```python
from transformers import Sam2Model

model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-small")

for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)
```

## Inference before training

```python
item = val_ds[1]
img = item["image"]
bbox = json.loads(item["prompt"])["bbox"]
inputs = processor(images=img, input_boxes=[[bbox]], return_tensors="pt").to(
    model.device
)

with torch.no_grad():
    outputs = model(**inputs)

masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
preds = masks.squeeze(0)
mask = (preds[0] > 0).cpu().numpy()

overlay = np.asarray(img, dtype=np.uint8).copy()
overlay[mask] = 0.55 * overlay[mask] + 0.45 * np.array([0, 255, 0], dtype=np.float32)

plt.imshow(overlay)
plt.title("Before training")
plt.axis("off")
plt.show()
```

## Define custom loss

SAM2 does not compute loss in its `forward()`, so we provide a `compute_loss_func` to the Trainer. The Trainer pops `"labels"` from the batch before calling `model(**inputs)`, then passes `(outputs, labels)` to this function.

```python
import monai
from transformers import Trainer, TrainingArguments
import trackio

seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction="mean")


def compute_loss(outputs, labels, num_items_in_batch=None):
    predicted_masks = outputs.pred_masks.squeeze(1)
    return seg_loss(predicted_masks, labels.float())
```

## Train with Trainer

Key settings:

- `remove_unused_columns=False`: the Trainer must keep `input_boxes`, `original_sizes`, etc. that are not in the model's `forward()` signature.
- `compute_loss_func`: our custom DiceCE loss.
- `report_to="trackio"`: logs the training loss to trackio.

```python
training_args = TrainingArguments(
    output_dir="sam2-finetuned",
    num_train_epochs=30,
    per_device_train_batch_size=4,
    learning_rate=1e-5,
    weight_decay=0,
    logging_steps=1,
    save_strategy="epoch",
    save_total_limit=2,
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    report_to="trackio",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
    compute_loss_func=compute_loss,
)

trainer.train()
```

## Inference after training

```python
item = val_ds[1]
img = item["image"]
bbox = json.loads(item["prompt"])["bbox"]

inputs = processor(images=img, input_boxes=[[bbox]], return_tensors="pt").to(
    model.device
)

with torch.no_grad():
    outputs = model(**inputs)

preds = processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"]
)[0]
preds = preds.squeeze(0)
mask = (preds[0] > 0).cpu().numpy()

overlay = np.asarray(img, dtype=np.uint8).copy()
overlay[mask] = 0.55 * overlay[mask] + 0.45 * np.array([0, 255, 0], dtype=np.float32)

plt.imshow(overlay)
plt.title("After training")
plt.axis("off")
plt.show()
```

# Saving Vision Models to Hugging Face Hub

## Contents

- Why Hub Push is Required
- Required Configuration (TrainingArguments, job config)
- Complete Example
- What Gets Saved
- Important: Save Image Processor
- Checkpoint Saving
- Model Card Configuration
- Saving Label Mappings
- Authentication Methods
- Verification Checklist
- Repository Setup (automatic/manual creation, naming)
- Troubleshooting (401, 403, push failures, inference issues)
- Manual Push After Training
- Example: Full Production Setup
- Inference Example

---

**CRITICAL:** Training environments are ephemeral. ALL results are lost when a job completes unless pushed to the Hub.

## Why Hub Push is Required

When running on Hugging Face Jobs:

- Environment is temporary
- All files deleted on job completion
- No local disk persistence
- Cannot access results after job ends

**Without Hub push, training is completely wasted.**

## Required Configuration

### 1. Training Configuration

In your TrainingArguments:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="my-object-detector",
    push_to_hub=True,                    # Enable Hub push
    hub_model_id="username/model-name",  # Target repository
)
```

### 2. Job Configuration

When submitting the job:

```python
hf_jobs("uv", {
    "script": training_script_content,    # Pass the Python script content directly as a string
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # Provide authentication
})
```

**The `$HF_TOKEN` syntax references your actual Hugging Face token value.**

## Complete Example

```python
# train_detector.py
# /// script
# dependencies = ["transformers", "torch", "torchvision", "datasets"]
# ///

from transformers import (
    AutoImageProcessor,
    AutoModelForObjectDetection,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
import os
import torch

# Load dataset
dataset = load_dataset("cppe-5", split="train")

# Load model and processor
model_name = "facebook/detr-resnet-50"
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForObjectDetection.from_pretrained(
    model_name,
    num_labels=5,  # Number of classes
    ignore_mismatched_sizes=True,
)

# Configure with Hub push
training_args = TrainingArguments(
    output_dir="my-detector",
    num_train_epochs=10,
    per_device_train_batch_size=8,

    # ✅ CRITICAL: Hub push configuration
    push_to_hub=True,
    hub_model_id="myusername/cppe5-detector",

    # Optional: Push strategy
    hub_strategy="checkpoint",  # Push checkpoints during training
)

# ✅ CRITICAL: Authenticate with Hub BEFORE creating Trainer
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN") or os.environ.get("hfjob")
if hf_token:
    login(token=hf_token)
    training_args.hub_token = hf_token
elif training_args.push_to_hub:
    raise ValueError("HF_TOKEN not found! Add secrets={'HF_TOKEN': '$HF_TOKEN'} to job config.")

# Define collate function
def collate_fn(batch):
    pixel_values = [item["pixel_values"] for item in batch]
    labels = [item["labels"] for item in batch]
    encoding = image_processor.pad(pixel_values, return_tensors="pt")
    return {
        "pixel_values": encoding["pixel_values"],
        "labels": labels,
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)

trainer.train()

# ✅ Push final model and processor
trainer.push_to_hub()
image_processor.push_to_hub("myusername/cppe5-detector")

print("✅ Model saved to: https://huggingface.co/myusername/cppe5-detector")
```

**Submit with authentication:**

```python
hf_jobs("uv", {
    "script": training_script_content,    # Pass script content as a string, NOT a filename
    "flavor": "a10g-large",
    "timeout": "4h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}  # ✅ Required!
})
```

## What Gets Saved

When `push_to_hub=True`:

1. **Model weights** - Final trained parameters
2. **Image processor** - Associated preprocessing configuration
3. **Configuration** - Model config (config.json) including:
   - Number of labels/classes
   - Architecture details (backbone, num_queries, etc.)
   - Label mappings (id2label, label2id)
4. **Training arguments** - Hyperparameters used
5. **Model card** - Auto-generated documentation
6. **Checkpoints** - If `save_strategy="steps"` enabled

## Important: Save Image Processor

**Object detection models require the image processor to be saved separately:**

```python
# After training completes
trainer.push_to_hub()

# ✅ Also push the image processor
image_processor.push_to_hub(
    repo_id="username/model-name",
    commit_message="Upload image processor",
)
```

**Why this matters:**

- Models need specific image preprocessing (resizing, normalization)
- Image processor contains critical configuration
- Without it, model cannot be used for inference

## Checkpoint Saving

Save intermediate checkpoints during training:

```python
TrainingArguments(
    output_dir="my-detector",
    push_to_hub=True,
    hub_model_id="username/my-detector",

    # Checkpoint configuration
    save_strategy="steps",
    save_steps=500,             # Save every 500 steps
    save_total_limit=3,         # Keep only last 3 checkpoints
    hub_strategy="checkpoint",  # Push checkpoints to Hub
)
```

**Benefits:**

- Resume training if job fails
- Compare checkpoint performance
- Use intermediate models
- Track training progress

**Checkpoints are pushed to:** `username/my-detector` (same repo)

## Model Card Configuration

Add metadata for better discoverability:

```python
# At the end of training script
model.push_to_hub(
    "username/my-detector",
    commit_message="Upload trained object detection model",
    tags=["object-detection", "vision", "cppe-5"],
    model_card_kwargs={
        "license": "apache-2.0",
        "dataset": "cppe-5",
        "metrics": ["map", "recall", "precision"],
        "pipeline_tag": "object-detection",
    },
)
```

## Saving Label Mappings

**Critical for object detection:** Save class labels with the model:

```python
# Define your label mappings
id2label = {0: "Coverall", 1: "Face_Shield", 2: "Gloves", 3: "Goggles", 4: "Mask"}
label2id = {v: k for k, v in id2label.items()}

# Update model config before training
model.config.id2label = id2label
model.config.label2id = label2id

# Now train and push
trainer.train()
trainer.push_to_hub()
```

**Without label mappings:**

- Model outputs will be numeric IDs only
- No human-readable class names
- Difficult to interpret results

## Authentication Methods

For a complete guide on token types, `$HF_TOKEN` automatic replacement, `secrets` vs `env` differences, and security best practices, see the `hugging-face-jobs` skill → *Token Usage Guide*.

**Recommended:** Always pass tokens via `secrets` (encrypted server-side):

```python
"secrets": {"HF_TOKEN": "$HF_TOKEN"}  # ✅ Automatic replacement with your logged-in token
```

## Verification Checklist

Before submitting any training job, verify:

- [ ] `push_to_hub=True` in TrainingArguments
- [ ] `hub_model_id` is specified (format: `username/model-name`)
- [ ] Image processor will be saved separately
- [ ] Label mappings (id2label, label2id) are configured
- [ ] Repository name doesn't conflict with existing repos
- [ ] You have write access to the target namespace

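The machine-checkable parts of this checklist can be sketched as a small validator. This is a hypothetical helper, not part of any library; `check_hub_config` and its dict keys are names invented here for illustration:

```python
import re

def check_hub_config(args: dict) -> list[str]:
    """Return a list of problems found in a training-job config (hypothetical helper)."""
    problems = []
    if not args.get("push_to_hub"):
        problems.append("push_to_hub is not True: results will be lost when the job ends")
    repo = args.get("hub_model_id", "")
    # Repo ids must look like "namespace/name": no spaces, one slash
    if not re.fullmatch(r"[A-Za-z0-9][\w.-]*/[\w.-]+", repo):
        problems.append(f"hub_model_id {repo!r} is not in 'username/model-name' form")
    if not args.get("id2label"):
        problems.append("id2label mapping missing: outputs will be numeric IDs only")
    return problems

print(check_hub_config({"push_to_hub": True, "hub_model_id": "user/detr-cppe5", "id2label": {0: "Mask"}}))  # → []
```

A config that fails any check returns a human-readable problem per item, so the job can be rejected before GPU time is spent.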
## Repository Setup

### Automatic Creation

If the repository doesn't exist, it is created automatically on the first push.

### Manual Creation

Create the repository before training:

```python
from huggingface_hub import HfApi

api = HfApi()
api.create_repo(
    repo_id="username/detector-name",
    repo_type="model",
    private=False,  # or True for a private repo
)
```

### Repository Naming

**Valid names:**
- `username/detr-cppe5`
- `username/yolos-object-detector`
- `organization/custom-detector`

**Invalid names:**
- `detector-name` (missing username)
- `username/detector name` (spaces not allowed)
- `username/DETECTOR` (uppercase discouraged)

**Recommended naming:**
- Include the model architecture: `detr-`, `yolos-`, `deta-`
- Include the dataset: `-cppe5`, `-coco`, `-voc`
- Be descriptive: `detr-resnet50-cppe5` > `model1`

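The recommended pattern (architecture plus dataset under your namespace) can be captured in a tiny helper. `suggest_repo_name` is a hypothetical function invented here for illustration:

```python
def suggest_repo_name(username: str, architecture: str, dataset: str) -> str:
    """Compose a descriptive Hub repo id: namespace/architecture-dataset (sketch)."""
    name = f"{architecture}-{dataset}".lower().replace(" ", "-")
    return f"{username}/{name}"

print(suggest_repo_name("username", "detr-resnet50", "cppe5"))  # → username/detr-resnet50-cppe5
```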
## Troubleshooting

### Error: 401 Unauthorized

**Cause:** HF_TOKEN not provided, invalid, or not authenticated before Trainer init

**Solutions:**
1. Verify `secrets={"HF_TOKEN": "$HF_TOKEN"}` in the job config
2. Verify the script calls `login(token=hf_token)` AND sets `training_args.hub_token = hf_token` BEFORE creating the `Trainer`
3. Check you're logged in locally: `hf auth whoami`
4. Re-login: `hf auth login`

**Root cause:** The `Trainer` calls `create_repo(token=self.args.hub_token)` during `__init__()` when `push_to_hub=True`. Relying on implicit env-var token resolution is unreliable in Jobs. Calling `login()` saves the token globally, and setting `training_args.hub_token` ensures the Trainer passes it explicitly to all Hub API calls.

### Error: 403 Forbidden

**Cause:** No write access to the repository

**Solutions:**
1. Check the repository namespace matches your username
2. Verify you're a member of the organization (if using an org namespace)
3. Check the repository isn't private (if accessing an org repo)

### Error: Repository not found

**Cause:** Repository doesn't exist and auto-creation failed

**Solutions:**
1. Manually create the repository first
2. Check the repository name format
3. Verify the namespace exists

### Error: Push failed during training

**Cause:** Network issues or Hub unavailable

**What happens and what to do:**
1. Training continues, but the final push fails
2. Checkpoints may already be saved
3. Re-run the push manually after the job completes

### Issue: Model loads but inference fails

**Possible causes:**
1. Image processor not saved (verify it's pushed separately)
2. Label mappings missing (check that config.json has id2label)
3. Wrong image size (verify the image processor matches the training config)

### Issue: Model saved but not visible

**Possible causes:**
1. Repository is private (check https://huggingface.co/username)
2. Wrong namespace (verify `hub_model_id` matches your login)
3. Push still in progress (wait a few minutes)

## Manual Push After Training

If training completes but the push fails, push manually:

```python
from transformers import AutoModelForObjectDetection, AutoImageProcessor

# Load from the local checkpoint
model = AutoModelForObjectDetection.from_pretrained("./output_dir")
image_processor = AutoImageProcessor.from_pretrained("./output_dir")

# Push to Hub
model.push_to_hub("username/model-name", token="hf_abc123...")
image_processor.push_to_hub("username/model-name", token="hf_abc123...")
```

**Note:** This is only possible while the checkpoint files still exist, i.e. before the job environment shuts down (Jobs storage is ephemeral).

## Best Practices

1. **Always enable `push_to_hub=True`**
2. **Save the image processor separately** - critical for inference
3. **Configure label mappings** before training
4. **Use checkpoint saving** for long training runs
5. **Verify the Hub push** in logs before the job completes
6. **Set an appropriate `save_total_limit`** to avoid excessive checkpoints
7. **Use descriptive repo names** (e.g., `detr-cppe5`, not `detector1`)
8. **Add a model card** with:
   - Training dataset
   - Evaluation metrics (mAP, IoU)
   - Example usage code
   - Limitations
9. **Tag models appropriately**:
   - `object-detection`
   - Architecture: `detr`, `yolos`, `deta`
   - Dataset: `coco`, `voc`, `cppe-5`

## Monitoring Push Progress

Check logs for push progress:

```python
hf_jobs("logs", {"job_id": "your-job-id"})
```

**Look for:**
```
Pushing model to username/detector-name...
Upload file pytorch_model.bin: 100%
✅ Model pushed successfully
Pushing image processor...
✅ Image processor pushed successfully
```

## Example: Full Production Setup

```python
# production_detector.py
# /// script
# dependencies = [
#     "transformers>=4.30.0",
#     "torch>=2.0.0",
#     "torchvision>=0.15.0",
#     "datasets>=2.12.0",
#     "evaluate>=0.4.0"
# ]
# ///

from transformers import (
    AutoImageProcessor,
    AutoModelForObjectDetection,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset
import os
import torch

# Configuration
MODEL_NAME = "facebook/detr-resnet-50"
DATASET_NAME = "cppe-5"
HUB_MODEL_ID = "myusername/detr-cppe5-detector"
NUM_CLASSES = 5

# Class labels
id2label = {0: "Coverall", 1: "Face_Shield", 2: "Gloves", 3: "Goggles", 4: "Mask"}
label2id = {v: k for k, v in id2label.items()}

print(f"🔧 Loading dataset: {DATASET_NAME}")
dataset = load_dataset(DATASET_NAME, split="train")
print(f"✅ Dataset loaded: {len(dataset)} examples")

print(f"🔧 Loading model: {MODEL_NAME}")
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForObjectDetection.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)
print("✅ Model loaded")

# Configure with comprehensive Hub settings
training_args = TrainingArguments(
    output_dir="detr-cppe5",

    # Hub configuration
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="checkpoint",  # Push checkpoints

    # Checkpoint configuration
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,

    # Training settings
    num_train_epochs=10,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    warmup_steps=500,

    # Evaluation
    eval_strategy="steps",
    eval_steps=500,

    # Logging
    logging_steps=50,
    logging_first_step=True,

    # Performance
    fp16=True,  # Mixed precision training
    dataloader_num_workers=4,
)

# ✅ CRITICAL: Authenticate with Hub BEFORE creating Trainer
# login() saves the token globally so ALL hub operations can find it.
from huggingface_hub import login
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("hfjob")
if hf_token:
    login(token=hf_token)
    training_args.hub_token = hf_token
elif training_args.push_to_hub:
    raise ValueError("HF_TOKEN not found! Add secrets={'HF_TOKEN': '$HF_TOKEN'} to job config.")

# Data collator
def collate_fn(batch):
    pixel_values = [item["pixel_values"] for item in batch]
    labels = [item["labels"] for item in batch]
    encoding = image_processor.pad(pixel_values, return_tensors="pt")
    return {
        "pixel_values": encoding["pixel_values"],
        "labels": labels
    }

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)

print("🚀 Starting training...")
trainer.train()

print("💾 Pushing final model to Hub...")
trainer.push_to_hub(
    commit_message="Upload trained DETR model on CPPE-5",
    tags=["object-detection", "detr", "cppe-5", "vision"],
)

print("💾 Pushing image processor to Hub...")
image_processor.push_to_hub(
    repo_id=HUB_MODEL_ID,
    commit_message="Upload image processor"
)

print("✅ Training complete!")
print(f"Model available at: https://huggingface.co/{HUB_MODEL_ID}")
print("\nTo use your model:")
print("```python")
print("from transformers import AutoImageProcessor, AutoModelForObjectDetection")
print("")
print(f"processor = AutoImageProcessor.from_pretrained('{HUB_MODEL_ID}')")
print(f"model = AutoModelForObjectDetection.from_pretrained('{HUB_MODEL_ID}')")
print("```")
```

**Submit:**

```python
hf_jobs("uv", {
    "script": training_script_content,  # Pass script content as a string, NOT a filename
    "flavor": "a10g-large",
    "timeout": "8h",
    "secrets": {"HF_TOKEN": "$HF_TOKEN"}
})
```

## Inference Example

After training, use your model:

```python
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image
import torch

# Load model from the Hub
processor = AutoImageProcessor.from_pretrained("username/detr-cppe5-detector")
model = AutoModelForObjectDetection.from_pretrained("username/detr-cppe5-detector")

# Load and process an image
image = Image.open("test_image.jpg")
inputs = processor(images=image, return_tensors="pt")

# Run inference
with torch.no_grad():
    outputs = model(**inputs)

# Post-process results
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(
    outputs,
    threshold=0.5,
    target_sizes=target_sizes
)[0]

# Print detections
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )
```

## Key Takeaway

**Without `push_to_hub=True` and `secrets={"HF_TOKEN": "$HF_TOKEN"}`, all training results are permanently lost.**

**For object detection, also remember to:**
1. Save the image processor separately
2. Configure label mappings (id2label, label2id)
3. Include appropriate model card metadata

Always verify all three are configured before submitting any training job.
@@ -0,0 +1,279 @@
# Image classification

## Contents
- Load Food-101 dataset
- Preprocess (ViT image processor, torchvision transforms)
- Evaluate (accuracy metric, compute_metrics)
- Train (TrainingArguments, Trainer setup, push to Hub)
- Inference (pipeline, manual prediction)

---

Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the pixel values that comprise an image. There are many applications for image classification, such as detecting damage after a natural disaster, monitoring crop health, or helping screen medical images for signs of disease.

This guide illustrates how to:

1. Fine-tune [ViT](../model_doc/vit) on the [Food-101](https://huggingface.co/datasets/ethz/food101) dataset to classify a food item in an image.
2. Use your fine-tuned model for inference.

To see all architectures and checkpoints compatible with this task, we recommend checking the [task page](https://huggingface.co/tasks/image-classification).

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install transformers datasets evaluate accelerate pillow torchvision scikit-learn trackio
```

We encourage you to log in to your Hugging Face account to upload and share your model with the community. When prompted, enter your token to log in:

```py
>>> from huggingface_hub import notebook_login

>>> notebook_login()
```

## Load Food-101 dataset

Start by loading a smaller subset of the Food-101 dataset from the 🤗 Datasets library. This will give you a chance to experiment and make sure everything works before spending more time training on the full dataset.

```py
>>> from datasets import load_dataset

>>> food = load_dataset("ethz/food101", split="train[:5000]")
```

Split the dataset's `train` split into a train and test set with the [train_test_split](https://huggingface.co/docs/datasets/v4.5.0/en/package_reference/main_classes#datasets.Dataset.train_test_split) method:

```py
>>> food = food.train_test_split(test_size=0.2)
```

Then take a look at an example:

```py
>>> food["train"][0]
{'image': ,
 'label': 79}
```

Each example in the dataset has two fields:

- `image`: a PIL image of the food item
- `label`: the label class of the food item

To make it easier for the model to get the label name from the label id, create a dictionary that maps the label name to an integer and vice versa:

```py
>>> labels = food["train"].features["label"].names
>>> label2id, id2label = dict(), dict()
>>> for i, label in enumerate(labels):
...     label2id[label] = str(i)
...     id2label[str(i)] = label
```

Now you can convert the label id to a label name:

```py
>>> id2label[str(79)]
'prime_rib'
```

## Preprocess

The next step is to load a ViT image processor to process the image into a tensor:

```py
>>> from transformers import AutoImageProcessor

>>> checkpoint = "google/vit-base-patch16-224-in21k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
```

Apply some image transformations to the images to make the model more robust against overfitting. Here you'll use torchvision's [`transforms`](https://pytorch.org/vision/stable/transforms.html) module, but you can also use any image library you like.

Crop a random part of the image, resize it, and normalize it with the image mean and standard deviation:

```py
>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

>>> normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
>>> size = (
...     image_processor.size["shortest_edge"]
...     if "shortest_edge" in image_processor.size
...     else (image_processor.size["height"], image_processor.size["width"])
... )
>>> _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
```

Then create a preprocessing function to apply the transforms and return the `pixel_values` - the inputs to the model - of the image:

```py
>>> def transforms(examples):
...     examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
...     del examples["image"]
...     return examples
```

To apply the preprocessing function over the entire dataset, use the 🤗 Datasets [with_transform](https://huggingface.co/docs/datasets/v4.5.0/en/package_reference/main_classes#datasets.Dataset.with_transform) method. The transforms are applied on the fly when you load an element of the dataset:

```py
>>> food = food.with_transform(transforms)
```

Now create a batch of examples using [DefaultDataCollator](/docs/transformers/v5.2.0/en/main_classes/data_collator#transformers.DefaultDataCollator). Unlike other data collators in 🤗 Transformers, the `DefaultDataCollator` does not apply additional preprocessing such as padding.

```py
>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator()
```

## Evaluate

Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an evaluation method with the 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) library. For this task, load the [accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy) metric (see the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about how to load and compute a metric):

```py
>>> import evaluate

>>> accuracy = evaluate.load("accuracy")
```

Then create a function that passes your predictions and labels to [compute](https://huggingface.co/docs/evaluate/v0.4.6/en/package_reference/main_classes#evaluate.EvaluationModule.compute) to calculate the accuracy:

```py
>>> import numpy as np

>>> def compute_metrics(eval_pred):
...     predictions, labels = eval_pred
...     predictions = np.argmax(predictions, axis=1)
...     return accuracy.compute(predictions=predictions, references=labels)
```

Your `compute_metrics` function is ready to go now, and you'll return to it when you set up your training.

## Train

If you aren't familiar with finetuning a model with the [Trainer](/docs/transformers/v5.2.0/en/main_classes/trainer#transformers.Trainer), take a look at the basic tutorial [here](../training#train-with-pytorch-trainer)!

You're ready to start training your model now! Load ViT with [AutoModelForImageClassification](/docs/transformers/v5.2.0/en/model_doc/auto#transformers.AutoModelForImageClassification). Specify the number of expected labels along with the label mappings:

```py
>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

>>> model = AutoModelForImageClassification.from_pretrained(
...     checkpoint,
...     num_labels=len(labels),
...     id2label=id2label,
...     label2id=label2id,
... )
```

At this point, only three steps remain:

1. Define your training hyperparameters in [TrainingArguments](/docs/transformers/v5.2.0/en/main_classes/trainer#transformers.TrainingArguments). It is important you don't remove unused columns because that'll drop the `image` column. Without the `image` column, you can't create `pixel_values`. Set `remove_unused_columns=False` to prevent this behavior! The only other required parameter is `output_dir`, which specifies where to save your model. You'll push this model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model). At the end of each epoch, the [Trainer](/docs/transformers/v5.2.0/en/main_classes/trainer#transformers.Trainer) will evaluate the accuracy and save the training checkpoint.
2. Pass the training arguments to [Trainer](/docs/transformers/v5.2.0/en/main_classes/trainer#transformers.Trainer) along with the model, dataset, image processor, data collator, and `compute_metrics` function.
3. Call [train()](/docs/transformers/v5.2.0/en/main_classes/trainer#transformers.Trainer.train) to finetune your model.

```py
>>> training_args = TrainingArguments(
...     output_dir="my_awesome_food_model",
...     remove_unused_columns=False,
...     eval_strategy="epoch",
...     save_strategy="epoch",
...     learning_rate=5e-5,
...     per_device_train_batch_size=16,
...     gradient_accumulation_steps=4,
...     per_device_eval_batch_size=16,
...     num_train_epochs=3,
...     warmup_ratio=0.1,
...     logging_steps=10,
...     report_to="trackio",
...     run_name="food101",
...     load_best_model_at_end=True,
...     metric_for_best_model="accuracy",
...     push_to_hub=True,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     data_collator=data_collator,
...     train_dataset=food["train"],
...     eval_dataset=food["test"],
...     processing_class=image_processor,
...     compute_metrics=compute_metrics,
... )

>>> trainer.train()
```

Once training is completed, share your model to the Hub with the [push_to_hub()](/docs/transformers/v5.2.0/en/main_classes/trainer#transformers.Trainer.push_to_hub) method so everyone can use your model:

```py
>>> trainer.push_to_hub()
```

For a more in-depth example of how to finetune a model for image classification, take a look at the corresponding [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).

## Inference

Great, now that you've fine-tuned a model, you can use it for inference!

Load an image you'd like to run inference on:

```py
>>> ds = load_dataset("ethz/food101", split="validation[:10]")
>>> image = ds["image"][0]
```

The simplest way to try out your finetuned model for inference is to use it in a [pipeline()](/docs/transformers/v5.2.0/en/main_classes/pipelines#transformers.pipeline). Instantiate a `pipeline` for image classification with your model, and pass your image to it:

```py
>>> from transformers import pipeline

>>> classifier = pipeline("image-classification", model="my_awesome_food_model")
>>> classifier(image)
[{'score': 0.31856709718704224, 'label': 'beignets'},
 {'score': 0.015232225880026817, 'label': 'bruschetta'},
 {'score': 0.01519392803311348, 'label': 'chicken_wings'},
 {'score': 0.013022331520915031, 'label': 'pork_chop'},
 {'score': 0.012728818692266941, 'label': 'prime_rib'}]
```

You can also manually replicate the results of the `pipeline` if you'd like:

Load an image processor to preprocess the image and return the `input` as PyTorch tensors:

```py
>>> from transformers import AutoImageProcessor
>>> import torch

>>> image_processor = AutoImageProcessor.from_pretrained("my_awesome_food_model")
>>> inputs = image_processor(image, return_tensors="pt")
```

Pass your inputs to the model and return the logits:

```py
>>> from transformers import AutoModelForImageClassification

>>> model = AutoModelForImageClassification.from_pretrained("my_awesome_food_model")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
```

Get the predicted label with the highest probability, and use the model's `id2label` mapping to convert it to a label:

```py
>>> predicted_label = logits.argmax(-1).item()
>>> model.config.id2label[predicted_label]
'beignets'
```
@@ -0,0 +1,700 @@
# Object Detection Training Reference

## Contents
- Load the CPPE-5 dataset
- Preprocess the data (augmentation with Albumentations, COCO annotation formatting)
- Preparing a function to compute mAP
- Training the detection model (TrainingArguments, Trainer setup)
- Evaluate
- Inference (loading from Hub, running predictions, visualizing results)

---

Object detection is the computer vision task of detecting instances (such as humans, buildings, or cars) in an image. Object detection models receive an image as input and output coordinates of the bounding boxes and associated labels of the detected objects. An image can contain multiple objects, each with its own bounding box and a label (e.g. it can have a car and a building), and each object can be present in different parts of an image (e.g. the image can have several cars). This task is commonly used in autonomous driving for detecting things like pedestrians, road signs, and traffic lights. Other applications include counting objects in images, image search, and more.

In this guide, you will learn how to:

1. Finetune [DETR](https://huggingface.co/docs/transformers/model_doc/detr), a model that combines a convolutional backbone with an encoder-decoder Transformer, on the [CPPE-5](https://huggingface.co/datasets/cppe-5) dataset.
2. Use your finetuned model for inference.

To see all architectures and checkpoints compatible with this task, we recommend checking the [task page](https://huggingface.co/tasks/object-detection).

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install -q datasets transformers accelerate timm trackio
pip install -q -U "albumentations>=1.4.5" torchmetrics pycocotools
```

You'll use 🤗 Datasets to load a dataset from the Hugging Face Hub, 🤗 Transformers to train your model, and `albumentations` to augment the data.

We encourage you to share your model with the community. Log in to your Hugging Face account to upload it to the Hub. When prompted, enter your token to log in:

```py
>>> from huggingface_hub import notebook_login

>>> notebook_login()
```

To get started, we'll define global constants, namely the model name and image size. For this tutorial, we'll use the conditional DETR model due to its faster convergence. Feel free to select any object detection model available in the `transformers` library.

```py
>>> MODEL_NAME = "microsoft/conditional-detr-resnet-50"  # or "facebook/detr-resnet-50"
>>> IMAGE_SIZE = 480
```

## Load the CPPE-5 dataset

The [CPPE-5 dataset](https://huggingface.co/datasets/cppe-5) contains images with annotations identifying medical personal protective equipment (PPE) in the context of the COVID-19 pandemic.

Start by loading the dataset and creating a `validation` split from `train`:

```py
>>> from datasets import load_dataset

>>> cppe5 = load_dataset("cppe-5")

>>> if "validation" not in cppe5:
...     split = cppe5["train"].train_test_split(0.15, seed=1337)
...     cppe5["train"] = split["train"]
...     cppe5["validation"] = split["test"]

>>> cppe5
DatasetDict({
    train: Dataset({
        features: ['image_id', 'image', 'width', 'height', 'objects'],
        num_rows: 850
    })
    test: Dataset({
        features: ['image_id', 'image', 'width', 'height', 'objects'],
        num_rows: 29
    })
    validation: Dataset({
        features: ['image_id', 'image', 'width', 'height', 'objects'],
        num_rows: 150
    })
})
```

You'll see that this dataset has 1000 images split between the train and validation sets, and a test set with 29 images.

To get familiar with the data, explore what the examples look like.

```py
>>> cppe5["train"][0]
{
    'image_id': 366,
    'image': ,
    'width': 500,
    'height': 500,
    'objects': {
        'id': [1932, 1933, 1934],
        'area': [27063, 34200, 32431],
        'bbox': [[29.0, 11.0, 97.0, 279.0],
                 [201.0, 1.0, 120.0, 285.0],
                 [382.0, 0.0, 113.0, 287.0]],
        'category': [0, 0, 0]
    }
}
```

The examples in the dataset have the following fields:

- `image_id`: the example image id
- `image`: a `PIL.Image.Image` object containing the image
- `width`: width of the image
- `height`: height of the image
- `objects`: a dictionary containing bounding box metadata for the objects in the image:
  - `id`: the annotation id
  - `area`: the area of the bounding box
  - `bbox`: the object's bounding box (in the [COCO format](https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco))
  - `category`: the object's category, with possible values including `Coverall (0)`, `Face_Shield (1)`, `Gloves (2)`, `Goggles (3)` and `Mask (4)`

You may notice that the `bbox` field follows the COCO format, which is the format that the DETR model expects. However, the grouping of the fields inside `objects` differs from the annotation format DETR requires. You will need to apply some preprocessing transformations before using this data for training.

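The regrouping can be sketched in plain Python. This is a hypothetical helper written for illustration (`format_annotations` is not a library function); it turns the column-wise `objects` dict into the per-object COCO-style annotation list that DETR's image processor expects:

```python
def format_annotations(image_id: int, objects: dict) -> dict:
    """Regroup a CPPE-5 `objects` dict into a per-object annotation list (sketch)."""
    annotations = [
        {
            "image_id": image_id,
            "category_id": category,
            "bbox": list(bbox),  # COCO format: [x_min, y_min, width, height]
            "area": area,
            "iscrowd": 0,
        }
        for category, bbox, area in zip(
            objects["category"], objects["bbox"], objects["area"]
        )
    ]
    return {"image_id": image_id, "annotations": annotations}

example = {
    "category": [0, 0],
    "bbox": [[29.0, 11.0, 97.0, 279.0], [201.0, 1.0, 120.0, 285.0]],
    "area": [27063, 34200],
}
print(format_annotations(366, example)["annotations"][0]["category_id"])  # → 0
```

The guide's actual preprocessing later performs the same regrouping inside the augmentation pipeline.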
To get an even better understanding of the data, visualize an example in the dataset.
|
||||
|
||||
```py
|
||||
>>> import numpy as np
|
||||
>>> import os
|
||||
>>> from PIL import Image, ImageDraw
|
||||
|
||||
>>> image = cppe5["train"][2]["image"]
|
||||
>>> annotations = cppe5["train"][2]["objects"]
|
||||
>>> draw = ImageDraw.Draw(image)
|
||||
|
||||
>>> categories = cppe5["train"].features["objects"]["category"].feature.names
|
||||
|
||||
>>> id2label = {index: x for index, x in enumerate(categories, start=0)}
|
||||
>>> label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
>>> for i in range(len(annotations["id"])):
|
||||
... box = annotations["bbox"][i]
|
||||
... class_idx = annotations["category"][i]
|
||||
... x, y, w, h = tuple(box)
|
||||
... # Check if coordinates are normalized or not
|
||||
... if max(box) > 1.0:
|
||||
... # Coordinates are un-normalized, no need to re-scale them
|
||||
... x1, y1 = int(x), int(y)
|
||||
... x2, y2 = int(x + w), int(y + h)
|
||||
... else:
|
||||
... # Coordinates are normalized, re-scale them
|
||||
... x1 = int(x * width)
|
||||
... y1 = int(y * height)
|
||||
... x2 = int((x + w) * width)
|
||||
... y2 = int((y + h) * height)
|
||||
... draw.rectangle((x, y, x + w, y + h), outline="red", width=1)
|
||||
... draw.text((x, y), id2label[class_idx], fill="white")
|
||||
|
||||
>>> image
|
||||
```
|
||||
|
||||
|
||||
|
||||
To visualize the bounding boxes with associated labels, you can get the labels from the dataset's metadata, specifically
the `category` field.
You'll also want to create dictionaries that map a label id to a label class (`id2label`) and the other way around (`label2id`).
You can use them later when setting up the model. Including these maps will make your model reusable by others if you share
it on the Hugging Face Hub. Note that the part of the code above that draws the bounding boxes assumes the boxes are in `COCO` format `(x_min, y_min, width, height)`; it has to be adjusted to work with other formats like `(x_min, y_min, x_max, y_max)`.

As a final step of getting familiar with the data, explore it for potential issues. One common problem with datasets for
object detection is bounding boxes that "stretch" beyond the edge of the image. Such "runaway" bounding boxes can raise
errors during training and should be addressed. There are a few examples with this issue in this dataset.
To keep things simple in this guide, we will set `clip=True` for `BboxParams` in the transformations below.

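What `clip=True` does for COCO-format boxes can be sketched in plain Python (a minimal illustration, independent of Albumentations; the helper name is ours):

```python
def clip_bbox(bbox, image_size):
    """Clamp a COCO-format (x_min, y_min, width, height) box to the image bounds."""
    height, width = image_size
    x, y, w, h = bbox
    x = min(max(x, 0), width)
    y = min(max(y, 0), height)
    w = max(0, min(w, width - x))
    h = max(0, min(h, height - y))
    return [x, y, w, h]

# A "runaway" box whose right edge extends past a 480x480 image:
print(clip_bbox([400, 100, 150, 60], (480, 480)))  # [400, 100, 80, 60]
```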
## Preprocess the data

To finetune a model, you must preprocess the data you plan to use to match precisely the approach used for the pre-trained model.
[AutoImageProcessor](/docs/transformers/v5.1.0/en/model_doc/auto#transformers.AutoImageProcessor) takes care of processing image data to create `pixel_values`, `pixel_mask`, and
`labels` that a DETR model can train with. The image processor has some attributes that you won't have to worry about:

- `image_mean = [0.485, 0.456, 0.406]`
- `image_std = [0.229, 0.224, 0.225]`

These are the mean and standard deviation used to normalize images during the model pre-training. These values are crucial
to replicate when doing inference or finetuning a pre-trained image model.

Instantiate the image processor from the same checkpoint as the model you want to finetune.

```py
>>> from transformers import AutoImageProcessor

>>> MAX_SIZE = IMAGE_SIZE

>>> image_processor = AutoImageProcessor.from_pretrained(
...     MODEL_NAME,
...     do_resize=True,
...     size={"max_height": MAX_SIZE, "max_width": MAX_SIZE},
...     do_pad=True,
...     pad_size={"height": MAX_SIZE, "width": MAX_SIZE},
... )
```

Before passing the images to the `image_processor`, apply two preprocessing transformations to the dataset:

- Augmenting images
- Reformatting annotations to meet DETR expectations

First, to make sure the model does not overfit on the training data, you can apply image augmentation with any data augmentation library. Here we use [Albumentations](https://albumentations.ai/docs/).
This library ensures that transformations affect the image and update the bounding boxes accordingly.
The 🤗 Datasets library documentation has a detailed [guide on how to augment images for object detection](https://huggingface.co/docs/datasets/object_detection),
and it uses the exact same dataset as an example. Apply some geometric and color transformations to the image. For additional augmentation options, explore the [Albumentations Demo Space](https://huggingface.co/spaces/qubvel-hf/albumentations-demo).

```py
>>> import albumentations as A

>>> train_augment_and_transform = A.Compose(
...     [
...         A.Perspective(p=0.1),
...         A.HorizontalFlip(p=0.5),
...         A.RandomBrightnessContrast(p=0.5),
...         A.HueSaturationValue(p=0.1),
...     ],
...     bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True, min_area=25),
... )

>>> validation_transform = A.Compose(
...     [A.NoOp()],
...     bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True),
... )
```

The `image_processor` expects the annotations to be in the following format: `{'image_id': int, 'annotations': list[Dict]}`,
where each dictionary is a COCO object annotation. Let's add a function to reformat annotations for a single example:

```py
>>> def format_image_annotations_as_coco(image_id, categories, areas, bboxes):
...     """Format one set of image annotations to the COCO format
...
...     Args:
...         image_id (str): image id. e.g. "0001"
...         categories (list[int]): list of categories/class labels corresponding to provided bounding boxes
...         areas (list[float]): list of corresponding areas to provided bounding boxes
...         bboxes (list[tuple[float]]): list of bounding boxes provided in COCO format
...             ([x_min, y_min, width, height] in absolute coordinates)
...
...     Returns:
...         dict: {
...             "image_id": image id,
...             "annotations": list of formatted annotations
...         }
...     """
...     annotations = []
...     for category, area, bbox in zip(categories, areas, bboxes):
...         formatted_annotation = {
...             "image_id": image_id,
...             "category_id": category,
...             "iscrowd": 0,
...             "area": area,
...             "bbox": list(bbox),
...         }
...         annotations.append(formatted_annotation)
...
...     return {
...         "image_id": image_id,
...         "annotations": annotations,
...     }
```

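For a quick look at the output shape, here is the reformatting applied to one illustrative image with a single box (a condensed, self-contained copy of the function; the values are made up):

```python
def format_image_annotations_as_coco(image_id, categories, areas, bboxes):
    # Condensed copy of the tutorial's function, for a standalone demonstration
    annotations = [
        {"image_id": image_id, "category_id": c, "iscrowd": 0, "area": a, "bbox": list(b)}
        for c, a, b in zip(categories, areas, bboxes)
    ]
    return {"image_id": image_id, "annotations": annotations}

result = format_image_annotations_as_coco("0001", [2], [1200.0], [(10.0, 20.0, 30.0, 40.0)])
print(result["annotations"][0]["category_id"])  # 2
print(result["annotations"][0]["bbox"])  # [10.0, 20.0, 30.0, 40.0]
```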
Now you can combine the image and annotation transformations to use on a batch of examples:

```py
>>> def augment_and_transform_batch(examples, transform, image_processor, return_pixel_mask=False):
...     """Apply augmentations and format annotations in COCO format for object detection task"""
...
...     images = []
...     annotations = []
...     for image_id, image, objects in zip(examples["image_id"], examples["image"], examples["objects"]):
...         image = np.array(image.convert("RGB"))
...
...         # apply augmentations
...         output = transform(image=image, bboxes=objects["bbox"], category=objects["category"])
...         images.append(output["image"])
...
...         # format annotations in COCO format
...         formatted_annotations = format_image_annotations_as_coco(
...             image_id, output["category"], objects["area"], output["bboxes"]
...         )
...         annotations.append(formatted_annotations)
...
...     # Apply the image processor transformations: resizing, rescaling, normalization
...     result = image_processor(images=images, annotations=annotations, return_tensors="pt")
...
...     if not return_pixel_mask:
...         result.pop("pixel_mask", None)
...
...     return result
```

Apply this preprocessing function to the entire dataset using the 🤗 Datasets [with_transform](https://huggingface.co/docs/datasets/v4.5.0/en/package_reference/main_classes#datasets.Dataset.with_transform) method. This method applies
transformations on the fly when you load an element of the dataset.

At this point, you can check what an example from the dataset looks like after the transformations. You should see a tensor
with `pixel_values`, a tensor with `pixel_mask`, and `labels`.

```py
>>> from functools import partial

>>> # Make transform functions for batch and apply for dataset splits
>>> train_transform_batch = partial(
...     augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
... )
>>> validation_transform_batch = partial(
...     augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
... )

>>> cppe5["train"] = cppe5["train"].with_transform(train_transform_batch)
>>> cppe5["validation"] = cppe5["validation"].with_transform(validation_transform_batch)
>>> cppe5["test"] = cppe5["test"].with_transform(validation_transform_batch)

>>> cppe5["train"][15]
{'pixel_values': tensor([[[ 1.9235,  1.9407,  1.9749,  ..., -0.7822, -0.7479, -0.6965],
          [ 1.9578,  1.9749,  1.9920,  ..., -0.7993, -0.7650, -0.7308],
          [ 2.0092,  2.0092,  2.0263,  ..., -0.8507, -0.8164, -0.7822],
          ...,
          [ 0.0741,  0.0741,  0.0741,  ...,  0.0741,  0.0741,  0.0741],
          [ 0.0741,  0.0741,  0.0741,  ...,  0.0741,  0.0741,  0.0741],
          [ 0.0741,  0.0741,  0.0741,  ...,  0.0741,  0.0741,  0.0741]],

         [[ 1.6232,  1.6408,  1.6583,  ...,  0.8704,  1.0105,  1.1331],
          [ 1.6408,  1.6583,  1.6758,  ...,  0.8529,  0.9930,  1.0980],
          [ 1.6933,  1.6933,  1.7108,  ...,  0.8179,  0.9580,  1.0630],
          ...,
          [ 0.2052,  0.2052,  0.2052,  ...,  0.2052,  0.2052,  0.2052],
          [ 0.2052,  0.2052,  0.2052,  ...,  0.2052,  0.2052,  0.2052],
          [ 0.2052,  0.2052,  0.2052,  ...,  0.2052,  0.2052,  0.2052]],

         [[ 1.8905,  1.9080,  1.9428,  ..., -0.1487, -0.0964, -0.0615],
          [ 1.9254,  1.9428,  1.9603,  ..., -0.1661, -0.1138, -0.0790],
          [ 1.9777,  1.9777,  1.9951,  ..., -0.2010, -0.1138, -0.0790],
          ...,
          [ 0.4265,  0.4265,  0.4265,  ...,  0.4265,  0.4265,  0.4265],
          [ 0.4265,  0.4265,  0.4265,  ...,  0.4265,  0.4265,  0.4265],
          [ 0.4265,  0.4265,  0.4265,  ...,  0.4265,  0.4265,  0.4265]]]),
 'labels': {'image_id': tensor([688]), 'class_labels': tensor([3, 4, 2, 0, 0]), 'boxes': tensor([[0.4700, 0.1933, 0.1467, 0.0767],
         [0.4858, 0.2600, 0.1150, 0.1000],
         [0.4042, 0.4517, 0.1217, 0.1300],
         [0.4242, 0.3217, 0.3617, 0.5567],
         [0.6617, 0.4033, 0.5400, 0.4533]]), 'area': tensor([ 4048.,  4140.,  5694., 72478., 88128.]), 'iscrowd': tensor([0, 0, 0, 0, 0]), 'orig_size': tensor([480, 480])}}
```

You have successfully augmented the individual images and prepared their annotations. However, preprocessing isn't
complete yet. In the final step, create a custom `collate_fn` to batch images together.
Pad images (which are now `pixel_values`) to the largest image in a batch, and create a corresponding `pixel_mask`
to indicate which pixels are real (1) and which are padding (0).

```py
>>> import torch

>>> def collate_fn(batch):
...     data = {}
...     data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
...     data["labels"] = [x["labels"] for x in batch]
...     if "pixel_mask" in batch[0]:
...         data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
...     return data
```

## Preparing function to compute mAP

Object detection models are commonly evaluated with a set of COCO-style metrics. We are going to use `torchmetrics` to compute `mAP` (mean average precision) and `mAR` (mean average recall) metrics, and will wrap them in a `compute_metrics` function in order to use them with [Trainer](/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.Trainer) for evaluation.

The intermediate format of boxes used for training is `YOLO` (normalized), but we will compute metrics for boxes in `Pascal VOC` (absolute) format in order to correctly handle box areas. Let's define a function that converts bounding boxes to `Pascal VOC` format:

```py
>>> from transformers.image_transforms import center_to_corners_format

>>> def convert_bbox_yolo_to_pascal(boxes, image_size):
...     """
...     Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
...     to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.
...
...     Args:
...         boxes (torch.Tensor): Bounding boxes in YOLO format
...         image_size (tuple[int, int]): Image size in format (height, width)
...
...     Returns:
...         torch.Tensor: Bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max)
...     """
...     # convert center to corners format
...     boxes = center_to_corners_format(boxes)
...
...     # convert to absolute coordinates
...     height, width = image_size
...     boxes = boxes * torch.tensor([[width, height, width, height]])
...
...     return boxes
```

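As a sanity check on the conversion, here is the same math in plain Python for a single box (a minimal sketch that inlines the center-to-corners step instead of importing `center_to_corners_format`):

```python
def yolo_to_pascal(box, image_size):
    """Convert one (center_x, center_y, width, height) box in [0, 1]
    to absolute (x_min, y_min, x_max, y_max) coordinates."""
    cx, cy, w, h = box
    height, width = image_size
    return (
        (cx - w / 2) * width,
        (cy - h / 2) * height,
        (cx + w / 2) * width,
        (cy + h / 2) * height,
    )

# A box centered in a 100x200 (height x width) image, 20% wide and 40% tall:
print(yolo_to_pascal((0.5, 0.5, 0.2, 0.4), (100, 200)))  # (80.0, 30.0, 120.0, 70.0)
```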
Then, in the `compute_metrics` function, collect the predicted and target bounding boxes, scores, and labels from the evaluation loop results and pass them to the scoring function.

```py
>>> import numpy as np
>>> from dataclasses import dataclass
>>> from torchmetrics.detection.mean_ap import MeanAveragePrecision

>>> @dataclass
... class ModelOutput:
...     logits: torch.Tensor
...     pred_boxes: torch.Tensor

>>> @torch.no_grad()
... def compute_metrics(evaluation_results, image_processor, threshold=0.0, id2label=None):
...     """
...     Compute mAP, mAR and their variants for the object detection task.
...
...     Args:
...         evaluation_results (EvalPrediction): Predictions and targets from evaluation.
...         threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
...         id2label (Optional[dict], optional): Mapping from class id to class name. Defaults to None.
...
...     Returns:
...         Mapping[str, float]: Metrics in a form of dictionary {<metric_name>: <metric_value>}
...     """
...
...     predictions, targets = evaluation_results.predictions, evaluation_results.label_ids
...
...     # For metric computation we need to provide:
...     #  - targets in a form of list of dictionaries with keys "boxes", "labels"
...     #  - predictions in a form of list of dictionaries with keys "boxes", "scores", "labels"
...
...     image_sizes = []
...     post_processed_targets = []
...     post_processed_predictions = []
...
...     # Collect targets in the required format for metric computation
...     for batch in targets:
...         # collect image sizes, we will need them for predictions post processing
...         batch_image_sizes = torch.tensor(np.array([x["orig_size"] for x in batch]))
...         image_sizes.append(batch_image_sizes)
...         # collect targets in the required format for metric computation
...         # boxes were converted to YOLO format needed for model training
...         # here we will convert them to Pascal VOC format (x_min, y_min, x_max, y_max)
...         for image_target in batch:
...             boxes = torch.tensor(image_target["boxes"])
...             boxes = convert_bbox_yolo_to_pascal(boxes, image_target["orig_size"])
...             labels = torch.tensor(image_target["class_labels"])
...             post_processed_targets.append({"boxes": boxes, "labels": labels})
...
...     # Collect predictions in the required format for metric computation:
...     # the model produces boxes in YOLO format, then image_processor converts them to Pascal VOC format
...     for batch, target_sizes in zip(predictions, image_sizes):
...         batch_logits, batch_boxes = batch[1], batch[2]
...         output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
...         post_processed_output = image_processor.post_process_object_detection(
...             output, threshold=threshold, target_sizes=target_sizes
...         )
...         post_processed_predictions.extend(post_processed_output)
...
...     # Compute metrics
...     metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
...     metric.update(post_processed_predictions, post_processed_targets)
...     metrics = metric.compute()
...
...     # Replace list of per class metrics with separate metric for each class
...     classes = metrics.pop("classes")
...     map_per_class = metrics.pop("map_per_class")
...     mar_100_per_class = metrics.pop("mar_100_per_class")
...     for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
...         class_name = id2label[class_id.item()] if id2label is not None else class_id.item()
...         metrics[f"map_{class_name}"] = class_map
...         metrics[f"mar_100_{class_name}"] = class_mar
...
...     metrics = {k: round(v.item(), 4) for k, v in metrics.items()}
...
...     return metrics

>>> eval_compute_metrics_fn = partial(
...     compute_metrics, image_processor=image_processor, id2label=id2label, threshold=0.0
... )
```

## Training the detection model

You have done most of the heavy lifting in the previous sections, so now you are ready to train your model!
The images in this dataset are still quite large, even after resizing. This means that finetuning this model will
require at least one GPU.

Training involves the following steps:

1. Load the model with [AutoModelForObjectDetection](/docs/transformers/v5.1.0/en/model_doc/auto#transformers.AutoModelForObjectDetection) using the same checkpoint as in the preprocessing.
2. Define your training hyperparameters in [TrainingArguments](/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.TrainingArguments).
3. Pass the training arguments to [Trainer](/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.Trainer) along with the model, dataset, image processor, and data collator.
4. Call [train()](/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.Trainer.train) to finetune your model.

When loading the model from the same checkpoint that you used for the preprocessing, remember to pass the `label2id`
and `id2label` maps that you created earlier from the dataset's metadata. Additionally, we specify `ignore_mismatched_sizes=True` to replace the existing classification head with a new one.

```py
>>> from transformers import AutoModelForObjectDetection

>>> model = AutoModelForObjectDetection.from_pretrained(
...     MODEL_NAME,
...     id2label=id2label,
...     label2id=label2id,
...     ignore_mismatched_sizes=True,
... )
```

In the [TrainingArguments](/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.TrainingArguments) use `output_dir` to specify where to save your model, then configure hyperparameters as you see fit. With `num_train_epochs=30`, training takes about 35 minutes on a Google Colab T4 GPU; increase the number of epochs to get better results.

Important notes:

- Set `remove_unused_columns` to `False`.
- Set `eval_do_concat_batches=False` to get proper evaluation results. Images have different numbers of target boxes; if batches are concatenated, we will not be able to determine which boxes belong to a particular image.

If you wish to share your model by pushing to the Hub, set `push_to_hub` to `True` (you must be signed in to Hugging
Face to upload your model).

```py
>>> from transformers import TrainingArguments

>>> training_args = TrainingArguments(
...     output_dir="detr_finetuned_cppe5",
...     num_train_epochs=30,
...     fp16=False,
...     per_device_train_batch_size=8,
...     dataloader_num_workers=4,
...     learning_rate=5e-5,
...     lr_scheduler_type="cosine",
...     weight_decay=1e-4,
...     max_grad_norm=0.01,
...     metric_for_best_model="eval_map",
...     greater_is_better=True,
...     load_best_model_at_end=True,
...     eval_strategy="epoch",
...     save_strategy="epoch",
...     save_total_limit=2,
...     remove_unused_columns=False,
...     report_to="trackio",
...     run_name="cppe",
...     eval_do_concat_batches=False,
...     push_to_hub=True,
... )
```

Finally, bring everything together, and call [train()](/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.Trainer.train):

```py
>>> from transformers import Trainer

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=cppe5["train"],
...     eval_dataset=cppe5["validation"],
...     processing_class=image_processor,
...     data_collator=collate_fn,
...     compute_metrics=eval_compute_metrics_fn,
... )

>>> trainer.train()
```

Training runs for 30 epochs (~26 minutes on a T4 GPU for CPPE-5). Final epoch 30 results:

| Metric | Value |
|--------|-------|
| Training Loss | 0.994 |
| Validation Loss | 1.346 |
| mAP | 0.277 |
| mAP@50 | 0.555 |
| mAP@75 | 0.253 |
| mAR@100 | 0.443 |

Per-class mAP at epoch 30: Coverall 0.530, Face Shield 0.276, Gloves 0.175, Goggles 0.157, Mask 0.249.

Key observations:

- mAP improves rapidly in early epochs (0.009 at epoch 1 → 0.18 by epoch 10), then gradually converges
- Large objects are detected better (mAP_large=0.524) than small objects (mAP_small=0.148)
- Class imbalance is visible: Coverall has the highest mAP (0.530), Goggles the lowest (0.157)

<!-- Full per-epoch training metrics table omitted for brevity. -->

If you have set `push_to_hub` to `True` in the `training_args`, the training checkpoints are pushed to the
Hugging Face Hub. Upon training completion, push the final model to the Hub as well by calling the [push_to_hub()](/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.Trainer.push_to_hub) method.

```py
>>> trainer.push_to_hub()
```

## Evaluate

```py
>>> from pprint import pprint

>>> metrics = trainer.evaluate(eval_dataset=cppe5["test"], metric_key_prefix="test")
>>> pprint(metrics)
{'epoch': 30.0,
 'test_loss': 1.0877351760864258,
 'test_map': 0.4116,
 'test_map_50': 0.741,
 'test_map_75': 0.3663,
 'test_map_Coverall': 0.5937,
 'test_map_Face_Shield': 0.5863,
 'test_map_Gloves': 0.3416,
 'test_map_Goggles': 0.1468,
 'test_map_Mask': 0.3894,
 'test_map_large': 0.5637,
 'test_map_medium': 0.3257,
 'test_map_small': 0.3589,
 'test_mar_1': 0.323,
 'test_mar_10': 0.5237,
 'test_mar_100': 0.5587,
 'test_mar_100_Coverall': 0.6756,
 'test_mar_100_Face_Shield': 0.7294,
 'test_mar_100_Gloves': 0.4721,
 'test_mar_100_Goggles': 0.4125,
 'test_mar_100_Mask': 0.5038,
 'test_mar_large': 0.7283,
 'test_mar_medium': 0.4901,
 'test_mar_small': 0.4469,
 'test_runtime': 1.6526,
 'test_samples_per_second': 17.548,
 'test_steps_per_second': 2.42}
```

These results can be further improved by adjusting the hyperparameters in [TrainingArguments](/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.TrainingArguments). Give it a go!

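For instance, a longer schedule with warmup is a common first adjustment (a hedged sketch, not a tuned recipe; only the arguments changed from the configuration used above are commented, and all other arguments stay the same):

```py
>>> training_args = TrainingArguments(
...     output_dir="detr_finetuned_cppe5",
...     num_train_epochs=100,  # longer schedule; mAP was still improving at epoch 30
...     learning_rate=5e-5,
...     lr_scheduler_type="cosine",
...     warmup_ratio=0.1,  # warmup can stabilize early DETR training
...     weight_decay=1e-4,
...     eval_strategy="epoch",
...     save_strategy="epoch",
...     remove_unused_columns=False,
...     eval_do_concat_batches=False,
... )
```
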
## Inference

Now that you have finetuned a model, evaluated it, and uploaded it to the Hugging Face Hub, you can use it for inference.

```py
>>> import torch
>>> import requests

>>> from PIL import Image, ImageDraw
>>> from transformers import AutoImageProcessor, AutoModelForObjectDetection

>>> url = "https://images.pexels.com/photos/8413299/pexels-photo-8413299.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2"
>>> image = Image.open(requests.get(url, stream=True).raw)
```

Load the model and image processor from the Hugging Face Hub (skip this step if you want to use the model already trained in this session):

```py
>>> from accelerate import Accelerator

>>> device = Accelerator().device
>>> model_repo = "qubvel-hf/detr_finetuned_cppe5"

>>> image_processor = AutoImageProcessor.from_pretrained(model_repo)
>>> model = AutoModelForObjectDetection.from_pretrained(model_repo)
>>> model = model.to(device)
```

And detect bounding boxes:

```py
>>> with torch.no_grad():
...     inputs = image_processor(images=[image], return_tensors="pt")
...     outputs = model(**inputs.to(device))
...     target_sizes = torch.tensor([[image.size[1], image.size[0]]])
...     results = image_processor.post_process_object_detection(outputs, threshold=0.3, target_sizes=target_sizes)[0]

>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
...     box = [round(i, 2) for i in box.tolist()]
...     print(
...         f"Detected {model.config.id2label[label.item()]} with confidence "
...         f"{round(score.item(), 3)} at location {box}"
...     )
Detected Gloves with confidence 0.683 at location [244.58, 124.33, 300.35, 185.13]
Detected Mask with confidence 0.517 at location [143.73, 64.58, 219.57, 125.89]
Detected Gloves with confidence 0.425 at location [179.15, 155.57, 262.4, 226.35]
Detected Coverall with confidence 0.407 at location [307.13, -1.18, 477.82, 318.06]
Detected Coverall with confidence 0.391 at location [68.61, 126.66, 309.03, 318.89]
```

Let's plot the result:

```py
>>> draw = ImageDraw.Draw(image)

>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
...     box = [round(i, 2) for i in box.tolist()]
...     x, y, x2, y2 = tuple(box)
...     draw.rectangle((x, y, x2, y2), outline="red", width=1)
...     draw.text((x, y), model.config.id2label[label.item()], fill="white")

>>> image
```

# Reliability Principles for Training Jobs

## Contents
- Principle 1: Always Verify Before Use
- Principle 2: Prioritize Reliability Over Performance
- Principle 3: Create Atomic, Self-Contained Scripts
- Principle 4: Provide Clear Error Context
- Principle 5: Test the Happy Path on Known-Good Inputs
- Summary: The Reliability Checklist (pre-flight, script quality, job config)
- When Principles Conflict

---

These principles are derived from real production failures and successful fixes. Following them prevents common failure modes and ensures reliable job execution.

## Principle 1: Always Verify Before Use

**Rule:** Never assume repos, datasets, or resources exist. Verify with tools first.

### What It Prevents

- **Non-existent datasets** - Jobs fail immediately when a dataset doesn't exist
- **Typos in names** - Simple mistakes like "argilla-dpo-mix-7k" vs "ultrafeedback_binarized"
- **Incorrect paths** - Old or moved repos, renamed files
- **Missing dependencies** - Undocumented requirements

### How to Apply

**Before submitting ANY job:**

```python
# Verify dataset exists
dataset_search({"query": "dataset-name", "author": "author-name", "limit": 5})
hub_repo_details(["author/dataset-name"], repo_type="dataset")

# Verify model exists
hub_repo_details(["org/model-name"], repo_type="model")

# Check script/file paths (for URL-based scripts)
# Verify before using: https://github.com/user/repo/blob/main/script.py
```

**Examples that would have caught errors:**

```python
# ❌ WRONG: Assumed dataset exists
hf_jobs("uv", {
    "script": """...""",
    "env": {"DATASET": "trl-lib/argilla-dpo-mix-7k"}  # Doesn't exist!
})

# ✅ CORRECT: Verify first
dataset_search({"query": "argilla dpo", "author": "trl-lib"})
# Would show: "trl-lib/ultrafeedback_binarized" is the correct name

hub_repo_details(["trl-lib/ultrafeedback_binarized"], repo_type="dataset")
# Confirms it exists before using
```

### Implementation Checklist

- [ ] Check the dataset exists before training
- [ ] Test that script URLs are valid before submitting
- [ ] Check for recent updates/renames of resources
- [ ] Check the dataset format matches what the script expects

**Time cost:** 5-10 seconds
**Time saved:** Hours of failed job time + debugging

---

## Principle 2: Prioritize Reliability Over Performance

**Rule:** Default to what is most likely to succeed, not what is theoretically fastest.

### What It Prevents

- **Hardware incompatibilities** - Features that fail on certain GPUs
- **Unstable optimizations** - Speed-ups that cause crashes
- **Complex configurations** - More failure points
- **Build system issues** - Unreliable compilation methods

### How to Apply

**Choose reliability:**

```python
# ❌ RISKY: Aggressive optimization that may fail
TrainingArguments(
    torch_compile=True,        # Can fail on T4, A10G GPUs
    optim="adamw_bnb_8bit",    # Requires specific setup
    dataloader_num_workers=8,  # May cause OOM on small instances
    ...
)

# ✅ SAFE: Proven defaults
TrainingArguments(
    # torch_compile=True,      # Enable on H100 for ~20% speedup
    optim="adamw_torch",       # Standard, always works
    fp16=True,                 # Stable and fast on T4/A10G
    dataloader_num_workers=4,  # Conservative, reliable
    ...
)
```

### Real-World Example

**The `torch.compile` failure:**
- Added for "20% speedup" on H100
- **Failed fatally on T4-medium** with a cryptic error
- Misdiagnosed as a dataset issue (cost hours)
- **Fix:** Disable by default, add as an optional comment

**Result:** Reliability > 20% performance gain

### Implementation Checklist

- [ ] Use proven, standard configurations by default
- [ ] Comment out performance optimizations with hardware notes
- [ ] Use stable build systems (CMake > make)
- [ ] Test on target hardware before production
- [ ] Document known incompatibilities
- [ ] Provide "safe" and "fast" variants when needed

**Performance loss:** 10-20% in the best case
**Reliability gain:** 95%+ success rate vs 60-70%

---

## Principle 3: Create Atomic, Self-Contained Scripts

**Rule:** Scripts should work as complete, independent units. Don't remove parts to "simplify."

### What It Prevents

- **Missing dependencies** - Removed "unnecessary" packages that are actually required
- **Incomplete processes** - Skipped steps that seem redundant
- **Environment assumptions** - Scripts that need pre-setup
- **Partial failures** - Some parts work, others fail silently

### How to Apply

**Complete dependency specifications:**

```python
# ❌ INCOMPLETE: "Simplified" by removing dependencies
# /// script
# dependencies = [
#     "transformers",
#     "torch",
#     "datasets",
# ]
# ///

# ✅ COMPLETE: All dependencies explicit
# /// script
# dependencies = [
#     "transformers>=5.2.0",
#     "accelerate>=1.1.0",
#     "albumentations>=1.4.16",  # Required for augmentation + bbox handling
#     "timm",                    # Required for vision backbones
#     "datasets>=4.0",
#     "torchmetrics",            # Required for mAP/mAR computation
#     "pycocotools",             # Required for COCO evaluation
#     "trackio",                 # Required for metrics monitoring
#     "huggingface_hub",
# ]
# ///
```

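To complement the explicit dependency list, the script body can fail fast on anything missing before training starts (a minimal sketch; the module names here are stdlib stand-ins for a real job's heavy dependencies such as `albumentations` or `torchmetrics`):

```python
import importlib

# Stand-ins for a real job's dependency list,
# e.g. ["albumentations", "torchmetrics", "pycocotools"]
REQUIRED_MODULES = ["json", "math"]

def check_imports(modules):
    """Import each module up front so a missing dependency fails
    with a clear name instead of a mid-training crash."""
    missing = []
    for name in modules:
        try:
            importlib.import_module(name)
        except ImportError:
            missing.append(name)
    if missing:
        raise ImportError(f"Missing required modules: {', '.join(missing)}")

check_imports(REQUIRED_MODULES)
print("all dependencies importable")
```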
### Real-World Example

**The `albumentations` failure:**
- Original script had it: augmentations and bbox clipping worked fine
- "Simplified" version removed it: "not strictly needed for training"
- **Training crashed on bbox augmentation** — no fallback for COCO-format bbox handling
- Hard to debug: the error surfaced in data loading, not in augmentation setup
- **Fix:** Restore all original dependencies

**Result:** Don't remove dependencies without thorough testing
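
One cheap guard against this failure mode is to verify imports before any GPU time is spent. A minimal sketch — the module list and helper name are illustrative, not part of the skill's actual tooling:

```python
import importlib.util

def check_dependencies(modules):
    """Fail fast if any required package is missing from the environment."""
    missing = [m for m in modules if importlib.util.find_spec(m) is None]
    if missing:
        raise ImportError(f"Missing required packages: {', '.join(missing)}")

# Call at script start, before any expensive work, e.g.:
# check_dependencies(["transformers", "albumentations", "timm"])
```

This turns the hard-to-debug mid-training crash into an immediate, clearly worded failure at startup.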

### Implementation Checklist

- [ ] All dependencies in PEP 723 header with version pins
- [ ] All system packages installed by script
- [ ] No assumptions about a pre-existing environment
- [ ] No "optional" steps that are actually required
- [ ] Test scripts in a clean environment
- [ ] Document why each dependency is needed

**Complexity:** Slightly longer scripts
**Reliability:** Scripts "just work" every time

---

## Principle 4: Provide Clear Error Context

**Rule:** When things fail, make it obvious what went wrong and how to fix it.

### How to Apply

**Wrap subprocess calls:**

```python
# ❌ UNCLEAR: Silent failure
subprocess.run([...], check=True, capture_output=True)

# ✅ CLEAR: Shows what failed
try:
    result = subprocess.run(
        [...],
        check=True,
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    if result.stderr:
        print("Warnings:", result.stderr)
except subprocess.CalledProcessError as e:
    print("❌ Command failed!")
    print("STDOUT:", e.stdout)
    print("STDERR:", e.stderr)
    raise
```

**Validate inputs:**

```python
# ❌ UNCLEAR: Fails later with cryptic error
model = load_model(MODEL_NAME)

# ✅ CLEAR: Fails fast with clear message
if not MODEL_NAME:
    raise ValueError("MODEL_NAME environment variable not set!")

print(f"Loading model: {MODEL_NAME}")
try:
    model = load_model(MODEL_NAME)
    print("✅ Model loaded successfully")
except Exception as e:
    print(f"❌ Failed to load model: {MODEL_NAME}")
    print(f"Error: {e}")
    print("Hint: Check that the model exists on the Hub")
    raise
```

### Implementation Checklist

- [ ] Wrap external calls with try/except
- [ ] Print stdout/stderr on failure
- [ ] Validate environment variables early
- [ ] Add progress indicators (✅, ❌, 🔄)
- [ ] Include hints for common failures
- [ ] Log configuration at start

---

## Principle 5: Test the Happy Path on Known-Good Inputs

**Rule:** Before using new code in production, test it with inputs you know work.
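
A minimal, library-free sketch of the idea (the helper name is illustrative): run new code against (input, expected) pairs you already trust, and refuse to proceed if anything fails:

```python
def smoke_test(fn, cases):
    """Run fn on known-good (input, expected) pairs; return a list of failures."""
    failures = []
    for inp, expected in cases:
        try:
            out = fn(inp)
        except Exception as e:
            failures.append((inp, f"raised {type(e).__name__}: {e}"))
            continue
        if out != expected:
            failures.append((inp, f"got {out!r}, expected {expected!r}"))
    return failures

# Example: verify a preprocessing function before submitting the real job
assert smoke_test(str.upper, [("abc", "ABC"), ("x", "X")]) == []
```

An empty failure list on inputs that worked yesterday is cheap evidence that today's changes didn't break the happy path.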

## Summary: The Reliability Checklist

Before submitting ANY job:

### Pre-Flight Checks
- [ ] **Verified** all repos/datasets exist (hub_repo_details)
- [ ] **Tested** with known-good inputs if new code
- [ ] **Using** proven hardware/configuration
- [ ] **Included** all dependencies in PEP 723 header
- [ ] **Installed** system requirements (build tools, etc.)
- [ ] **Set** appropriate timeout (not the default 30m)
- [ ] **Configured** Hub push with HF_TOKEN (login() + hub_token)
- [ ] **Added** clear error handling

### Script Quality
- [ ] Self-contained (no external setup needed)
- [ ] Complete dependencies listed
- [ ] Build tools installed by script
- [ ] Progress indicators included
- [ ] Error messages are clear
- [ ] Configuration logged at start

### Job Configuration
- [ ] Timeout > expected runtime + 30% buffer
- [ ] Hardware appropriate for model size
- [ ] Secrets include HF_TOKEN (see SKILL.md directive #2 for syntax)
- [ ] Script calls `login(token=hf_token)` and sets `training_args.hub_token = hf_token` BEFORE `Trainer()` init
- [ ] Environment variables set correctly
- [ ] Cost estimated and acceptable
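
The timeout rule in the checklist ("expected runtime + 30% buffer") is easy to mis-eyeball; a tiny illustrative helper (not part of the official tooling) makes it mechanical:

```python
import math

def job_timeout(expected_runtime_s: float, buffer: float = 0.30) -> str:
    """Timeout string: expected runtime plus a safety buffer, rounded up to whole minutes."""
    total_s = expected_runtime_s * (1 + buffer)
    return f"{math.ceil(total_s / 60)}m"

# A job expected to run 1 hour gets a 78-minute timeout:
# job_timeout(3600) -> "78m"
```

Rounding up to whole minutes keeps the buffer conservative for short jobs, where a few seconds of variance matters proportionally more.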

**Following these principles transforms the job success rate from ~60-70% to ~95%+.**

---

## When Principles Conflict

Sometimes reliability and performance conflict. Here's how to choose:

| Scenario | Choose | Rationale |
|----------|--------|-----------|
| Demo/test | Reliability | Fast failure is worse than slow success |
| Production (first run) | Reliability | Prove it works before optimizing |
| Production (proven) | Performance | Safe to optimize after validation |
| Time-critical | Reliability | Failures cause more delay than slow runs |
| Cost-critical | Balanced | Test with small model, then optimize |

**General rule:** Reliability first, optimize second.

---

@@ -0,0 +1,91 @@

# Using timm models with Hugging Face Trainer

Transformers has first-class support for timm models via the `TimmWrapper` classes. You can load any timm model and use it directly with the `Trainer` API for image classification. Here's how it works.

## Loading a timm model

The `TimmWrapperForImageClassification` class (in `transformers/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py`) wraps timm models so they're fully compatible with the Trainer API. You can load them via the `Auto` classes:

```python
from transformers import AutoModelForImageClassification, AutoImageProcessor, Trainer, TrainingArguments

# Load a timm model for image classification
checkpoint = "timm/resnet50.a1_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=10,  # set to your number of classes
    ignore_mismatched_sizes=True,  # needed when changing num_labels from pretrained
)
```

## Key details

1. **Image processor**: The `TimmWrapperImageProcessor` automatically resolves the correct transforms from timm's config. It exposes both `val_transforms` and `train_transforms` (with augmentations), as noted in the code:

```64:65:transformers/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py
# useful for training, see examples/pytorch/image-classification/run_image_classification.py
self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True)
```

2. **Loss computation is built-in**: `TimmWrapperForImageClassification.forward()` accepts a `labels` argument and computes cross-entropy loss automatically, which is exactly what Trainer expects:

```374:376:transformers/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
loss = None
if labels is not None:
    loss = self.loss_function(labels, logits, self.config)
```

3. **Returns `ImageClassifierOutput`**: The output format is the standard transformers output, so Trainer handles it seamlessly.

## Full training example

```python
from transformers import AutoModelForImageClassification, AutoImageProcessor, Trainer, TrainingArguments
from datasets import load_dataset

# Load dataset
dataset = load_dataset("food101", split="train[:5000]")
dataset = dataset.train_test_split(test_size=0.2)

# Load timm model + processor
checkpoint = "timm/resnet50.a1_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=101,
    ignore_mismatched_sizes=True,
)

# Preprocessing
def transform(batch):
    batch["pixel_values"] = [image_processor(img)["pixel_values"][0] for img in batch["image"]]
    batch["labels"] = batch["label"]
    return batch

dataset["train"].set_transform(transform)
dataset["test"].set_transform(transform)

# Train
training_args = TrainingArguments(
    output_dir="./timm-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

trainer.train()
```

Any timm checkpoint on the Hub (prefixed with `timm/`) works out of the box (ResNet, EfficientNet, ViT, ConvNeXt, etc.). The wrapper handles all the translation between timm's interface and what Trainer expects.
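
Since the wrapper returns standard logits, the usual Trainer evaluation hook applies unchanged. A hedged sketch using plain NumPy — pass it as `compute_metrics=compute_metrics` when constructing the `Trainer` above:

```python
import numpy as np

def compute_metrics(eval_pred):
    """Accuracy from the (logits, labels) pair Trainer passes during evaluation."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)  # predicted class per example
    return {"accuracy": float((preds == labels).mean())}
```

With `eval_strategy="epoch"` set as in the example, the reported accuracy will then appear in the per-epoch evaluation logs.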

@@ -0,0 +1,814 @@

#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = []
# ///
"""
Dataset Format Inspector for Vision Model Training

Inspects Hugging Face datasets to determine compatibility with object detection
and image classification training.
Uses the Datasets Server API for instant results - no dataset download needed!

ULTRA-EFFICIENT: Uses the HF Datasets Server API - completes in <2 seconds.

Usage with HF Jobs:
    hf_jobs("uv", {
        "script": "path/to/dataset_inspector.py",
        "script_args": ["--dataset", "your/dataset", "--split", "train"]
    })
"""

import argparse
import json
import math
import sys
import urllib.error
import urllib.parse
import urllib.request
from typing import Any, Dict, List, Tuple


def parse_args():
    parser = argparse.ArgumentParser(description="Inspect dataset format for vision model training")
    parser.add_argument("--dataset", type=str, required=True, help="Dataset name")
    parser.add_argument("--split", type=str, default="train", help="Dataset split (default: train)")
    parser.add_argument("--config", type=str, default="default", help="Dataset config name (default: default)")
    parser.add_argument("--preview", type=int, default=150, help="Max chars per field preview")
    parser.add_argument("--samples", type=int, default=5, help="Number of samples to fetch (default: 5)")
    parser.add_argument("--json-output", action="store_true", help="Output as JSON")
    return parser.parse_args()


def api_request(url: str) -> Dict:
    """Make an API request to the Datasets Server."""
    try:
        with urllib.request.urlopen(url, timeout=10) as response:
            return json.loads(response.read().decode())
    except urllib.error.HTTPError as e:
        if e.code == 404:
            return None
        raise Exception(f"API request failed: {e.code} {e.reason}")
    except Exception as e:
        raise Exception(f"API request failed: {str(e)}")


def get_splits(dataset: str) -> Dict:
    """Get available splits for a dataset."""
    url = f"https://datasets-server.huggingface.co/splits?dataset={urllib.parse.quote(dataset)}"
    return api_request(url)


def get_rows(dataset: str, config: str, split: str, offset: int = 0, length: int = 5) -> Dict:
    """Get rows from a dataset."""
    url = (
        f"https://datasets-server.huggingface.co/rows?dataset={urllib.parse.quote(dataset)}"
        f"&config={config}&split={split}&offset={offset}&length={length}"
    )
    return api_request(url)


def find_columns(columns: List[str], patterns: List[str]) -> List[str]:
    """Find columns matching any of the given patterns."""
    return [c for c in columns if any(p in c.lower() for p in patterns)]


def detect_bbox_format(bbox: List[float], image_size: Tuple[int, int] = None) -> str:
    """
    Detect bounding box format based on values and, optionally, image dimensions.
    Common formats:
    - [x_min, y_min, x_max, y_max] - XYXY (Pascal VOC)
    - [x_min, y_min, width, height] - XYWH (COCO)
    - [x_center, y_center, width, height] - CXCYWH (YOLO normalized)
    """
    if len(bbox) != 4:
        return "unknown (not 4 values)"

    a, b, c, d = bbox

    is_normalized = all(0 <= v <= 1 for v in bbox)

    if c < a or d < b:
        if is_normalized:
            return "xywh_normalized"
        return "xywh (COCO style)"

    # c > a and d > b — ambiguous between xyxy and xywh.
    # Use image dimensions to disambiguate when available.
    if image_size is not None:
        img_w, img_h = image_size
        # If interpreting as xywh, right edge = a + c; if that overshoots the
        # image while c alone fits, the format is more likely xyxy.
        xywh_exceeds = (a + c > img_w * 1.05) or (b + d > img_h * 1.05)
        xyxy_exceeds = (c > img_w * 1.05) or (d > img_h * 1.05)
        if xywh_exceeds and not xyxy_exceeds:
            return "xyxy (Pascal VOC style)"
        if xyxy_exceeds and not xywh_exceeds:
            return "xywh (COCO style)"

    if is_normalized:
        return "xyxy_normalized"
    return "xyxy (Pascal VOC style)"


def _extract_image_size(row: Dict) -> Tuple[int, int] | None:
    """Try to extract (width, height) from the image column returned by the Datasets Server."""
    for col in ("image", "img", "picture", "photo"):
        img = row.get(col)
        if isinstance(img, dict):
            w = img.get("width")
            h = img.get("height")
            if isinstance(w, (int, float)) and isinstance(h, (int, float)):
                return (int(w), int(h))
    return None


def analyze_annotations(sample_rows: List[Dict], annotation_cols: List[str]) -> Dict[str, Any]:
    """Analyze annotation structure from sample rows."""
    if not annotation_cols:
        return {"found": False}

    annotation_col = annotation_cols[0]
    annotations_info = {
        "found": True,
        "column": annotation_col,
        "sample_structures": [],
        "bbox_formats": [],
        "categories_found": [],
        "avg_objects_per_image": 0,
        "max_objects": 0,
        "min_objects": float('inf'),
    }

    total_objects = 0
    valid_samples = 0

    for row in sample_rows:
        ann = row["row"].get(annotation_col)
        if not ann:
            continue

        valid_samples += 1
        image_size = _extract_image_size(row["row"])

        # Check if it's a list of annotations or a dict
        if isinstance(ann, dict):
            # COCO-style or structured annotation
            sample_structure = {
                "type": "dict",
                "keys": list(ann.keys()),
            }

            # Check for bounding boxes
            if "bbox" in ann or "bboxes" in ann:
                bbox_key = "bbox" if "bbox" in ann else "bboxes"
                bboxes = ann[bbox_key]
                if isinstance(bboxes, list) and len(bboxes) > 0:
                    if isinstance(bboxes[0], list):
                        # Multiple bboxes
                        num_objects = len(bboxes)
                        total_objects += num_objects
                        annotations_info["max_objects"] = max(annotations_info["max_objects"], num_objects)
                        annotations_info["min_objects"] = min(annotations_info["min_objects"], num_objects)

                        # Analyze first bbox format
                        bbox_format = detect_bbox_format(bboxes[0], image_size)
                        annotations_info["bbox_formats"].append(bbox_format)
                    else:
                        # Single bbox
                        total_objects += 1
                        annotations_info["max_objects"] = max(annotations_info["max_objects"], 1)
                        annotations_info["min_objects"] = min(annotations_info["min_objects"], 1)
                        bbox_format = detect_bbox_format(bboxes, image_size)
                        annotations_info["bbox_formats"].append(bbox_format)

            # Check for categories/classes
            for key in ["category", "categories", "label", "labels", "class", "classes", "category_id"]:
                if key in ann:
                    cats = ann[key]
                    if isinstance(cats, list):
                        annotations_info["categories_found"].extend([str(c) for c in cats])
                    else:
                        annotations_info["categories_found"].append(str(cats))

            annotations_info["sample_structures"].append(sample_structure)

        elif isinstance(ann, list):
            # List of annotation dicts
            sample_structure = {
                "type": "list",
                "length": len(ann),
                "item_type": type(ann[0]).__name__ if ann else None,
            }

            if ann and isinstance(ann[0], dict):
                sample_structure["item_keys"] = list(ann[0].keys())

            # Count objects
            num_objects = len(ann)
            total_objects += num_objects
            annotations_info["max_objects"] = max(annotations_info["max_objects"], num_objects)
            annotations_info["min_objects"] = min(annotations_info["min_objects"], num_objects)

            # Check first annotation
            first_ann = ann[0]
            if "bbox" in first_ann:
                bbox_format = detect_bbox_format(first_ann["bbox"], image_size)
                annotations_info["bbox_formats"].append(bbox_format)

            # Check for categories
            for key in ["category", "label", "class", "category_id"]:
                if key in first_ann:
                    for item in ann:
                        if key in item:
                            annotations_info["categories_found"].append(str(item[key]))

            annotations_info["sample_structures"].append(sample_structure)

    if valid_samples > 0:
        annotations_info["avg_objects_per_image"] = round(total_objects / valid_samples, 2)

    if annotations_info["min_objects"] == float('inf'):
        annotations_info["min_objects"] = 0

    # Get unique categories
    annotations_info["categories_found"] = list(set(annotations_info["categories_found"]))
    annotations_info["num_classes"] = len(annotations_info["categories_found"])

    # Get most common bbox format
    if annotations_info["bbox_formats"]:
        from collections import Counter
        format_counts = Counter(annotations_info["bbox_formats"])
        annotations_info["primary_bbox_format"] = format_counts.most_common(1)[0][0]

    return annotations_info


def check_image_classification_compatibility(columns: List[str], sample_rows: List[Dict], features: List[Dict]) -> Dict[str, Any]:
    """Check image classification dataset compatibility."""

    image_cols = find_columns(columns, ["image", "img", "picture", "photo"])
    has_image = len(image_cols) > 0

    label_cols = find_columns(columns, ["label", "labels", "class", "fine_label", "coarse_label"])
    has_label = len(label_cols) > 0

    label_info: Dict[str, Any] = {"found": has_label}

    if has_label:
        label_col = label_cols[0]
        label_info["column"] = label_col

        # Detect whether the label is a ClassLabel (int with names) or a plain int/string
        for f in features:
            if f.get("name") == label_col:
                ftype = f.get("type", "")
                if isinstance(ftype, dict) and ftype.get("_type") == "ClassLabel":
                    label_info["type"] = "ClassLabel"
                    names = ftype.get("names", [])
                    label_info["num_classes"] = len(names)
                    label_info["class_names"] = names[:20]
                    if len(names) > 20:
                        label_info["class_names_truncated"] = True
                elif isinstance(ftype, dict) and ftype.get("dtype") in ("int64", "int32", "int8"):
                    label_info["type"] = "int"
                elif isinstance(ftype, dict) and ftype.get("dtype") == "string":
                    label_info["type"] = "string"
                break

        # Discover unique labels from samples if ClassLabel info wasn't in features
        if "num_classes" not in label_info:
            unique = set()
            for row in sample_rows:
                val = row["row"].get(label_col)
                if val is not None:
                    unique.add(val)
            label_info["sample_unique_labels"] = sorted(unique, key=str)[:20]
            label_info["sample_unique_count"] = len(unique)

    ready = has_image and has_label
    return {
        "ready": ready,
        "has_image": has_image,
        "image_columns": image_cols,
        "has_label": has_label,
        "label_columns": label_cols,
        "label_info": label_info,
    }


def check_object_detection_compatibility(columns: List[str], sample_rows: List[Dict]) -> Dict[str, Any]:
    """Check object detection dataset compatibility."""

    # Find image column
    image_cols = find_columns(columns, ["image", "img", "picture", "photo"])
    has_image = len(image_cols) > 0

    # Find annotation columns
    annotation_cols = find_columns(columns, ["objects", "annotations", "ann", "bbox", "bboxes", "detection"])
    has_annotations = len(annotation_cols) > 0

    # Analyze annotations
    annotations_info = analyze_annotations(sample_rows, annotation_cols) if has_annotations else {"found": False}

    # Check for separate bbox and category columns
    bbox_cols = find_columns(columns, ["bbox", "bboxes", "boxes"])
    category_cols = find_columns(columns, ["category", "label", "class", "categories", "labels", "classes"])

    # Determine readiness
    ready = has_image and (has_annotations or (len(bbox_cols) > 0 and len(category_cols) > 0))

    return {
        "ready": ready,
        "has_image": has_image,
        "image_columns": image_cols,
        "has_annotations": has_annotations,
        "annotation_columns": annotation_cols,
        "separate_bbox_columns": bbox_cols,
        "separate_category_columns": category_cols,
        "annotations_info": annotations_info,
    }


def check_sam_segmentation_compatibility(columns: List[str], sample_rows: List[Dict], features: List[Dict]) -> Dict[str, Any]:
    """Check SAM/SAM2 segmentation dataset compatibility.

    A valid SAM segmentation dataset needs:
    - An image column
    - A mask column (binary ground-truth segmentation mask)
    - A prompt: either a bbox prompt or a point prompt (in a JSON prompt column, or dedicated columns)
    """

    image_cols = find_columns(columns, ["image", "img", "picture", "photo"])
    has_image = len(image_cols) > 0

    mask_cols = find_columns(columns, ["mask", "segmentation", "alpha", "matte"])
    has_mask = len(mask_cols) > 0

    prompt_cols = find_columns(columns, ["prompt"])
    bbox_cols = [c for c in columns if c in ("bbox", "bboxes", "box", "boxes")]
    point_cols = [c for c in columns if c in ("point", "points", "input_point", "input_points")]

    prompt_info: Dict[str, Any] = {
        "has_prompt": False,
        "prompt_type": None,
        "source": None,
        "bbox_valid": None,
    }

    # Try a JSON prompt column first
    if prompt_cols:
        for row in sample_rows:
            raw = row["row"].get(prompt_cols[0])
            if raw is None:
                continue
            parsed = raw if isinstance(raw, dict) else _try_json(raw)
            if parsed is None:
                continue

            if isinstance(parsed, dict):
                if "bbox" in parsed or "box" in parsed:
                    prompt_info["has_prompt"] = True
                    prompt_info["prompt_type"] = "bbox"
                    prompt_info["source"] = f"JSON column '{prompt_cols[0]}'"
                    bbox = parsed.get("bbox") or parsed.get("box")
                    prompt_info["bbox_valid"] = _validate_bbox(bbox, _extract_image_size(row["row"]))
                    break
                elif "point" in parsed or "points" in parsed:
                    prompt_info["has_prompt"] = True
                    prompt_info["prompt_type"] = "point"
                    prompt_info["source"] = f"JSON column '{prompt_cols[0]}'"
                    break

    if not prompt_info["has_prompt"] and bbox_cols:
        prompt_info["has_prompt"] = True
        prompt_info["prompt_type"] = "bbox"
        prompt_info["source"] = f"column '{bbox_cols[0]}'"
        for row in sample_rows:
            bbox = row["row"].get(bbox_cols[0])
            if bbox is not None:
                prompt_info["bbox_valid"] = _validate_bbox(bbox, _extract_image_size(row["row"]))
                break

    if not prompt_info["has_prompt"] and point_cols:
        prompt_info["has_prompt"] = True
        prompt_info["prompt_type"] = "point"
        prompt_info["source"] = f"column '{point_cols[0]}'"

    ready = has_image and has_mask and prompt_info["has_prompt"]

    return {
        "ready": ready,
        "has_image": has_image,
        "image_columns": image_cols,
        "has_mask": has_mask,
        "mask_columns": mask_cols,
        "prompt_columns": prompt_cols,
        "bbox_columns": bbox_cols,
        "point_columns": point_cols,
        "prompt_info": prompt_info,
    }


def _try_json(value) -> Any:
    if not isinstance(value, str):
        return None
    try:
        return json.loads(value)
    except (json.JSONDecodeError, TypeError):
        return None


def _validate_bbox(bbox, image_size=None) -> Dict[str, Any]:
    """Validate a single bounding box and return diagnostics."""
    result: Dict[str, Any] = {"valid": False}
    if not isinstance(bbox, (list, tuple)):
        result["error"] = "bbox is not a list"
        return result
    if len(bbox) != 4:
        result["error"] = f"expected 4 values, got {len(bbox)}"
        return result
    try:
        vals = [float(v) for v in bbox]
    except (TypeError, ValueError):
        result["error"] = "non-numeric values"
        return result

    if not all(math.isfinite(v) for v in vals):
        result["error"] = "contains non-finite values"
        return result

    x0, y0, x1, y1 = vals
    if x1 <= x0 or y1 <= y0:
        if vals[2] > 0 and vals[3] > 0:
            result["format_hint"] = "likely xywh"
        else:
            result["error"] = "degenerate bbox (zero or negative area)"
            return result
    else:
        result["format_hint"] = "likely xyxy"

    if image_size is not None:
        img_w, img_h = image_size
        if any(v > max(img_w, img_h) * 1.5 for v in vals):
            result["warning"] = "coordinates exceed image bounds"

    result["valid"] = True
    result["values"] = vals
    return result


def generate_mapping_code(info: Dict[str, Any]) -> str:
    """Generate mapping code if needed."""
    if info["ready"]:
        ann_info = info["annotations_info"]
        if not ann_info.get("found"):
            return None

        # Check if format conversion is needed
        ann_col = ann_info.get("column")
        bbox_format = ann_info.get("primary_bbox_format", "unknown")

        if "coco" in bbox_format.lower() or "xywh" in bbox_format.lower():
            # Already COCO format
            return f"""# Dataset appears to be in COCO format (xywh)
# Image column: {info['image_columns'][0] if info['image_columns'] else 'image'}
# Annotation column: {ann_col}
# Use directly with transformers object detection models"""
        elif "xyxy" in bbox_format.lower():
            # Need to convert from XYXY to XYWH
            return f"""# Convert from XYXY (Pascal VOC) to XYWH (COCO) format
def convert_to_coco_format(example):
    annotations = example['{ann_col}']
    if isinstance(annotations, list):
        for ann in annotations:
            if 'bbox' in ann:
                x_min, y_min, x_max, y_max = ann['bbox']
                ann['bbox'] = [x_min, y_min, x_max - x_min, y_max - y_min]
    elif isinstance(annotations, dict) and 'bbox' in annotations:
        bbox = annotations['bbox']
        if isinstance(bbox, list) and len(bbox) > 0 and isinstance(bbox[0], list):
            for i, box in enumerate(bbox):
                x_min, y_min, x_max, y_max = box
                bbox[i] = [x_min, y_min, x_max - x_min, y_max - y_min]
    return example

dataset = dataset.map(convert_to_coco_format)"""

    else:
        # Need to create an annotations structure
        if info["separate_bbox_columns"] and info["separate_category_columns"]:
            bbox_col = info["separate_bbox_columns"][0]
            cat_col = info["separate_category_columns"][0]

            return f"""# Combine separate bbox and category columns
def create_annotations(example):
    bboxes = example['{bbox_col}']
    categories = example['{cat_col}']

    if not isinstance(bboxes, list):
        bboxes = [bboxes]
    if not isinstance(categories, list):
        categories = [categories]

    annotations = []
    for bbox, cat in zip(bboxes, categories):
        annotations.append({{'bbox': bbox, 'category': cat}})

    example['objects'] = annotations
    return example

dataset = dataset.map(create_annotations)"""

    return None


def format_value_preview(value: Any, max_chars: int) -> str:
    """Format a value for preview."""
    if value is None:
        return "None"
    elif isinstance(value, str):
        return value[:max_chars] + ("..." if len(value) > max_chars else "")
    elif isinstance(value, dict):
        keys = list(value.keys())
        return f"{{dict with {len(keys)} keys: {', '.join(keys[:5])}}}"
    elif isinstance(value, list):
        if len(value) == 0:
            return "[]"
        elif isinstance(value[0], dict):
            return f"[{len(value)} items] First item keys: {list(value[0].keys())}"
        elif isinstance(value[0], list):
            return f"[{len(value)} items] First item: {value[0]}"
        else:
            preview = str(value)
            return preview[:max_chars] + ("..." if len(preview) > max_chars else "")
    else:
        preview = str(value)
        return preview[:max_chars] + ("..." if len(preview) > max_chars else "")


def main():
    args = parse_args()

    print("Fetching dataset info via the Datasets Server API...")

    try:
        # Get splits info
        splits_data = get_splits(args.dataset)
        if not splits_data or "splits" not in splits_data:
            print(f"ERROR: Could not fetch splits for dataset '{args.dataset}'")
            print("  Dataset may not exist or is not accessible via the Datasets Server API")
            sys.exit(1)

        # Find the right config
        available_configs = set()
        split_found = False
        config_to_use = args.config

        for split_info in splits_data["splits"]:
            available_configs.add(split_info["config"])
            if split_info["config"] == args.config and split_info["split"] == args.split:
                split_found = True

        # If the requested config is not found, try the first available one
        if not split_found and available_configs:
            config_to_use = list(available_configs)[0]
            print(f"Config '{args.config}' not found, trying '{config_to_use}'...")

        # Get rows
        rows_data = get_rows(args.dataset, config_to_use, args.split, offset=0, length=args.samples)

        if not rows_data or "rows" not in rows_data:
            print(f"ERROR: Could not fetch rows for dataset '{args.dataset}'")
            print(f"  Split '{args.split}' may not exist")
            print(f"  Available configs: {', '.join(sorted(available_configs))}")
            sys.exit(1)

        rows = rows_data["rows"]
        if not rows:
            print(f"ERROR: No rows found in split '{args.split}'")
            sys.exit(1)

        # Extract column info from the first row
        first_row = rows[0]["row"]
        columns = list(first_row.keys())
        features = rows_data.get("features", [])

        # Get the total count if available
        total_examples = "Unknown"
        for split_info in splits_data["splits"]:
            if split_info["config"] == config_to_use and split_info["split"] == args.split:
                num = split_info.get("num_examples")
                total_examples = f"{num:,}" if isinstance(num, int) else "Unknown"
                break

    except Exception as e:
        print(f"ERROR: {str(e)}")
        sys.exit(1)

    # Run compatibility checks
    od_info = check_object_detection_compatibility(columns, rows)
    ic_info = check_image_classification_compatibility(columns, rows, features)
    sam_info = check_sam_segmentation_compatibility(columns, rows, features)

    # JSON output mode
    if args.json_output:
        result = {
            "dataset": args.dataset,
            "config": config_to_use,
            "split": args.split,
            "total_examples": total_examples,
            "columns": columns,
            "features": [{"name": f["name"], "type": f["type"]} for f in features] if features else [],
            "object_detection_compatibility": od_info,
            "image_classification_compatibility": ic_info,
            "sam_segmentation_compatibility": sam_info,
        }
        print(json.dumps(result, indent=2))
        sys.exit(0)

    # Human-readable output optimized for LLM parsing
    print("=" * 80)
    print("VISION DATASET INSPECTION")
    print("=" * 80)

    print(f"\nDataset: {args.dataset}")
    print(f"Config: {config_to_use}")
    print(f"Split: {args.split}")
    print(f"Total examples: {total_examples}")
    print(f"Samples fetched: {len(rows)}")

    print(f"\n{'COLUMNS':-<80}")
    if features:
        for feature in features:
            print(f"  {feature['name']}: {feature['type']}")
    else:
        for col in columns:
            print(f"  {col}: (type info not available)")

    print(f"\n{'EXAMPLE DATA':-<80}")
    example = first_row
    for col in columns:
        value = example.get(col)
        display = format_value_preview(value, args.preview)
        print(f"\n{col}:")
        print(f"  {display}")

    # --- Image Classification ---
    print(f"\n{'IMAGE CLASSIFICATION COMPATIBILITY':-<80}")
    print(f"\n[STATUS] {'✓ READY' if ic_info['ready'] else '✗ NOT COMPATIBLE'}")
print(f"\nImage Column:")
|
||||
if ic_info["has_image"]:
|
||||
print(f" ✓ Found: {', '.join(ic_info['image_columns'])}")
|
||||
else:
|
||||
print(f" ✗ No image column detected")
|
||||
|
||||
print(f"\nLabel Column:")
|
||||
if ic_info["has_label"]:
|
||||
print(f" ✓ Found: {', '.join(ic_info['label_columns'])}")
|
||||
li = ic_info["label_info"]
|
||||
if li.get("type"):
|
||||
print(f" • Type: {li['type']}")
|
||||
if li.get("num_classes"):
|
||||
print(f" • Number of Classes: {li['num_classes']}")
|
||||
if li.get("class_names"):
|
||||
names = li["class_names"]
|
||||
display = ", ".join(str(n) for n in names[:10])
|
||||
if len(names) > 10:
|
||||
display += f" ... ({li['num_classes']} total)"
|
||||
print(f" • Classes: {display}")
|
||||
elif li.get("sample_unique_labels"):
|
||||
labels = li["sample_unique_labels"]
|
||||
display = ", ".join(str(l) for l in labels[:10])
|
||||
if li.get("sample_unique_count", 0) > 10:
|
||||
display += f" ... ({li['sample_unique_count']}+ from sample)"
|
||||
print(f" • Sample labels: {display}")
|
||||
else:
|
||||
print(f" ✗ No label column detected")
|
||||
print(f" Expected column names: 'label', 'labels', 'class', 'fine_label'")
|
||||
|
||||
if ic_info["ready"]:
|
||||
lc = ic_info["label_info"].get("column", "label")
|
||||
print(f"\n Use with: scripts/image_classification_training.py")
|
||||
print(f" --image_column_name {ic_info['image_columns'][0]} --label_column_name {lc}")
|
||||
|
||||
# --- Object Detection ---
|
||||
print(f"\n{'OBJECT DETECTION COMPATIBILITY':-<80}")
|
||||
print(f"\n[STATUS] {'✓ READY' if od_info['ready'] else '✗ NOT COMPATIBLE'}")
|
||||
|
||||
print(f"\nImage Column:")
|
||||
if od_info["has_image"]:
|
||||
print(f" ✓ Found: {', '.join(od_info['image_columns'])}")
|
||||
else:
|
||||
print(f" ✗ No image column detected")
|
||||
print(f" Expected column names: 'image', 'img', 'picture', 'photo'")
|
||||
|
||||
print(f"\nAnnotations:")
|
||||
if od_info["has_annotations"]:
|
||||
print(f" ✓ Found: {', '.join(od_info['annotation_columns'])}")
|
||||
ann_info = od_info["annotations_info"]
|
||||
if ann_info.get("found"):
|
||||
print(f"\n Annotation Details:")
|
||||
print(f" • Column: {ann_info['column']}")
|
||||
if ann_info.get("primary_bbox_format"):
|
||||
print(f" • BBox Format: {ann_info['primary_bbox_format']}")
|
||||
if ann_info.get("num_classes", 0) > 0:
|
||||
print(f" • Number of Classes: {ann_info['num_classes']}")
|
||||
print(f" • Classes: {', '.join(ann_info['categories_found'][:10])}")
|
||||
if len(ann_info['categories_found']) > 10:
|
||||
print(f" (showing first 10 of {len(ann_info['categories_found'])})")
|
||||
print(f" • Avg Objects/Image: {ann_info['avg_objects_per_image']}")
|
||||
print(f" • Min Objects: {ann_info['min_objects']}")
|
||||
print(f" • Max Objects: {ann_info['max_objects']}")
|
||||
elif od_info["separate_bbox_columns"] and od_info["separate_category_columns"]:
|
||||
print(f" ⚠ Separate bbox and category columns found:")
|
||||
print(f" BBox columns: {', '.join(od_info['separate_bbox_columns'])}")
|
||||
print(f" Category columns: {', '.join(od_info['separate_category_columns'])}")
|
||||
print(f" Action: These need to be combined (see mapping code below)")
|
||||
else:
|
||||
print(f" ✗ No annotation columns detected")
|
||||
print(f" Expected: 'objects', 'annotations', 'bbox'/'bboxes' + 'category'/'label'")
|
||||
|
||||
# --- SAM Segmentation ---
|
||||
print(f"\n{'SAM SEGMENTATION COMPATIBILITY':-<80}")
|
||||
print(f"\n[STATUS] {'✓ READY' if sam_info['ready'] else '✗ NOT COMPATIBLE'}")
|
||||
|
||||
print(f"\nImage Column:")
|
||||
if sam_info["has_image"]:
|
||||
print(f" ✓ Found: {', '.join(sam_info['image_columns'])}")
|
||||
else:
|
||||
print(f" ✗ No image column detected")
|
||||
|
||||
print(f"\nMask Column:")
|
||||
if sam_info["has_mask"]:
|
||||
print(f" ✓ Found: {', '.join(sam_info['mask_columns'])}")
|
||||
else:
|
||||
print(f" ✗ No mask column detected")
|
||||
print(f" Expected column names: 'mask', 'segmentation', 'alpha', 'matte'")
|
||||
|
||||
print(f"\nPrompt:")
|
||||
pi = sam_info["prompt_info"]
|
||||
if pi["has_prompt"]:
|
||||
print(f" ✓ Type: {pi['prompt_type']} (from {pi['source']})")
|
||||
if pi.get("bbox_valid"):
|
||||
bv = pi["bbox_valid"]
|
||||
if bv["valid"]:
|
||||
print(f" • BBox values: {bv.get('values')}")
|
||||
if bv.get("format_hint"):
|
||||
print(f" • Format: {bv['format_hint']}")
|
||||
if bv.get("warning"):
|
||||
print(f" ⚠ {bv['warning']}")
|
||||
else:
|
||||
print(f" ✗ Invalid bbox: {bv.get('error', 'unknown error')}")
|
||||
else:
|
||||
print(f" ✗ No prompt detected")
|
||||
print(f" Expected: 'prompt' column (JSON with bbox/point), or 'bbox'/'point' column")
|
||||
|
||||
if sam_info["ready"]:
|
||||
pc = sam_info["prompt_columns"][0] if sam_info["prompt_columns"] else None
|
||||
args_hint = f"--prompt_type {pi['prompt_type']}"
|
||||
if pc:
|
||||
args_hint += f" --prompt_column_name {pc}"
|
||||
print(f"\n Use with: scripts/sam_segmentation_training.py")
|
||||
print(f" {args_hint}")
|
||||
|
||||
# Mapping code (OD only)
|
||||
mapping_code = generate_mapping_code(od_info)
|
||||
|
||||
if mapping_code:
|
||||
print(f"\n{'OD PREPROCESSING CODE':-<80}")
|
||||
print(mapping_code)
|
||||
elif od_info["ready"]:
|
||||
print(f"\n ✓ No OD preprocessing needed.")
|
||||
|
||||
# --- Summary ---
|
||||
print(f"\n{'SUMMARY':-<80}")
|
||||
if ic_info["ready"]:
|
||||
num_cls = ic_info["label_info"].get("num_classes") or ic_info["label_info"].get("sample_unique_count", "?")
|
||||
print(f"✓ Image Classification: READY ({num_cls} classes)")
|
||||
else:
|
||||
print(f"✗ Image Classification: not compatible")
|
||||
|
||||
if od_info["ready"]:
|
||||
ann_info = od_info["annotations_info"]
|
||||
fmt = ann_info.get("primary_bbox_format", "")
|
||||
cls = ann_info.get("num_classes", "?")
|
||||
print(f"✓ Object Detection: READY ({cls} classes, {fmt})")
|
||||
else:
|
||||
print(f"✗ Object Detection: not compatible")
|
||||
|
||||
if sam_info["ready"]:
|
||||
print(f"✓ SAM Segmentation: READY (prompt: {pi['prompt_type']})")
|
||||
else:
|
||||
print(f"✗ SAM Segmentation: not compatible")
|
||||
|
||||
print(f"\nNote: Used Datasets Server API (instant, no download required)")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
print(f"ERROR: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
@@ -0,0 +1,217 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = []
# ///
"""
Estimate training time and cost for vision model training jobs on Hugging Face Jobs.

Usage:
    uv run estimate_cost.py --model ustc-community/dfine-small-coco --dataset cppe-5 --hardware t4-small
    uv run estimate_cost.py --model PekingU/rtdetr_v2_r50vd --dataset-size 5000 --hardware t4-small --epochs 30
    uv run estimate_cost.py --model google/vit-base-patch16-224-in21k --dataset ethz/food101 --hardware t4-small --epochs 3
"""

import argparse

# Approximate cost per hour (USD) for each hardware flavor
HARDWARE_COSTS = {
    "t4-small": 0.40,
    "t4-medium": 0.60,
    "l4x1": 0.80,
    "l4x4": 3.80,
    "a10g-small": 1.00,
    "a10g-large": 1.50,
    "a10g-largex2": 3.00,
    "a10g-largex4": 5.00,
    "l40sx1": 1.80,
    "l40sx4": 8.30,
    "a100-large": 2.50,
    "a100x4": 10.00,
}

# Vision model sizes in millions of parameters
MODEL_PARAMS_M = {
    # Object detection
    "dfine-small": 10.4,
    "dfine-large": 31.4,
    "dfine-xlarge": 63.5,
    "rtdetr_v2_r18vd": 20.2,
    "rtdetr_v2_r50vd": 43.0,
    "rtdetr_v2_r101vd": 76.0,
    "detr-resnet-50": 41.3,
    "detr-resnet-101": 60.2,
    "yolos-small": 30.7,
    "yolos-tiny": 6.5,
    # Image classification
    "mobilenetv3_small": 2.5,
    "mobilevit_s": 5.6,
    "resnet50": 25.6,
    "vit_base_patch16": 86.6,
    # SAM / SAM2 segmentation
    "sam-vit-base": 93.7,
    "sam-vit-large": 312.3,
    "sam-vit-huge": 641.1,
    "sam2.1-hiera-tiny": 38.9,
    "sam2.1-hiera-small": 46.0,
    "sam2.1-hiera-base-plus": 80.8,
    "sam2.1-hiera-large": 224.4,
}

# Training-split sizes (number of images) for a few known datasets
KNOWN_DATASETS = {
    # Object detection
    "cppe-5": 1000,
    "merve/license-plate": 6180,
    # Image classification
    "ethz/food101": 75750,
    # SAM segmentation
    "merve/MicroMat-mini": 240,
}


def extract_model_params(model_name: str) -> float:
    """Extract model size in millions of parameters from the model name."""
    name_lower = model_name.lower()
    for key, params in MODEL_PARAMS_M.items():
        if key.lower() in name_lower:
            return params
    return 30.0  # reasonable default for vision models


def estimate_training_time(model_params_m: float, dataset_size: int, epochs: int,
                           image_size: int, batch_size: int, hardware: str) -> float:
    """Estimate training time in hours for vision model training."""
    steps_per_epoch = dataset_size / batch_size

    # Empirical calibration values: scale a baseline seconds-per-step by model size,
    # image resolution, and batch size
    base_secs_per_step = 0.8
    model_factor = (model_params_m / 30.0) ** 0.6
    image_factor = (image_size / 640.0) ** 2
    batch_factor = (batch_size / 8.0) ** 0.7

    secs_per_step = base_secs_per_step * model_factor * image_factor * batch_factor

    # Relative speed of each flavor (a10g ≈ 1.0 baseline; lower is faster)
    hardware_multipliers = {
        "t4-small": 2.0,
        "t4-medium": 2.0,
        "l4x1": 1.2,
        "l4x4": 0.5,
        "a10g-small": 1.0,
        "a10g-large": 1.0,
        "a10g-largex2": 0.6,
        "a10g-largex4": 0.4,
        "l40sx1": 0.7,
        "l40sx4": 0.25,
        "a100-large": 0.5,
        "a100x4": 0.2,
    }

    multiplier = hardware_multipliers.get(hardware, 1.0)
    total_steps = steps_per_epoch * epochs
    total_secs = total_steps * secs_per_step * multiplier

    # Add overhead: model loading (~2 min), eval per epoch (~10% of training), Hub push (~3 min)
    eval_overhead = total_secs * 0.10
    fixed_overhead = 5 * 60  # 5 minutes
    total_secs += eval_overhead + fixed_overhead

    return total_secs / 3600


def parse_args():
    parser = argparse.ArgumentParser(description="Estimate training cost for vision model training jobs")
    parser.add_argument("--model", required=True,
                        help="Model name (e.g., 'ustc-community/dfine-small-coco' or 'detr-resnet-50')")
    parser.add_argument("--dataset", default=None, help="Dataset name (for known size lookup)")
    parser.add_argument("--hardware", required=True, choices=HARDWARE_COSTS.keys(), help="Hardware flavor")
    parser.add_argument("--dataset-size", type=int, default=None,
                        help="Number of training images (overrides dataset lookup)")
    parser.add_argument("--epochs", type=int, default=30, help="Number of training epochs (default: 30)")
    parser.add_argument("--image-size", type=int, default=640, help="Image square size in pixels (default: 640)")
    parser.add_argument("--batch-size", type=int, default=8, help="Per-device batch size (default: 8)")
    return parser.parse_args()


def main():
    args = parse_args()

    model_params = extract_model_params(args.model)
    print(f"Model: {args.model} (~{model_params:.1f}M parameters)")

    if args.dataset_size:
        dataset_size = args.dataset_size
    elif args.dataset and args.dataset in KNOWN_DATASETS:
        dataset_size = KNOWN_DATASETS[args.dataset]
    elif args.dataset:
        print(f"Unknown dataset '{args.dataset}', defaulting to 1000 images.")
        print("Use --dataset-size to specify the exact count.")
        dataset_size = 1000
    else:
        dataset_size = 1000

    print(f"Dataset: {args.dataset or 'custom'} (~{dataset_size} images)")
    print(f"Epochs: {args.epochs}")
    print(f"Image size: {args.image_size}px")
    print(f"Batch size: {args.batch_size}")
    print(f"Hardware: {args.hardware} (${HARDWARE_COSTS[args.hardware]:.2f}/hr)")
    print()

    estimated_hours = estimate_training_time(
        model_params, dataset_size, args.epochs, args.image_size, args.batch_size, args.hardware
    )
    estimated_cost = estimated_hours * HARDWARE_COSTS[args.hardware]
    recommended_timeout = estimated_hours * 1.3  # 30% buffer

    print(f"Estimated training time: {estimated_hours:.1f} hours")
    print(f"Estimated cost: ${estimated_cost:.2f}")
    print(f"Recommended timeout: {recommended_timeout:.1f}h (with 30% buffer)")
    print()

    if estimated_hours > 6:
        print("Warning: Long training time. Consider:")
        print("  - Reducing epochs or image size")
        print("  - Using --max_train_samples for a test run first")
        print("  - Upgrading hardware")
        print()

    if model_params > 50 and args.hardware in ("t4-small", "t4-medium"):
        print("Warning: Large model on T4. If you hit OOM:")
        print("  - Reduce batch size (try 4, then 2)")
        print("  - Reduce image size (try 480)")
        print("  - Upgrade to l4x1 or a10g-small")
        print()

    timeout_str = f"{recommended_timeout:.0f}h"
    timeout_secs = int(recommended_timeout * 3600)
    print("Example job configuration (MCP tool):")
    print(f"""
hf_jobs("uv", {{
    "script": "scripts/object_detection_training.py",
    "script_args": [
        "--model_name_or_path", "{args.model}",
        "--dataset_name", "{args.dataset or 'your-dataset'}",
        "--image_square_size", "{args.image_size}",
        "--num_train_epochs", "{args.epochs}",
        "--per_device_train_batch_size", "{args.batch_size}",
        "--push_to_hub", "--do_train", "--do_eval"
    ],
    "flavor": "{args.hardware}",
    "timeout": "{timeout_str}",
    "secrets": {{"HF_TOKEN": "$HF_TOKEN"}}
}})
""")
    print("Example job configuration (Python API):")
    print(f"""
api.run_uv_job(
    script="scripts/object_detection_training.py",
    script_args=[...],
    flavor="{args.hardware}",
    timeout={timeout_secs},
    secrets={{"HF_TOKEN": get_token()}},
)
""")


if __name__ == "__main__":
    main()
@@ -0,0 +1,383 @@
# /// script
# dependencies = [
#     "transformers>=5.2.0",
#     "accelerate>=1.1.0",
#     "timm",
#     "datasets>=4.0",
#     "evaluate",
#     "scikit-learn",
#     "torchvision",
#     "trackio",
#     "huggingface_hub",
# ]
# ///

"""Fine-tuning any Transformers or timm model supported by AutoModelForImageClassification using the Trainer API."""

import logging
import os
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import Any

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

import trackio

import transformers
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.trainer import EvalPrediction
from transformers.utils import check_min_version
from transformers.utils.versions import require_version


logger = logging.getLogger(__name__)

check_min_version("4.57.0.dev0")
require_version("datasets>=2.0.0")


@dataclass
class DataTrainingArguments:
    dataset_name: str = field(
        default="ethz/food101",
        metadata={"help": "Name of a dataset from the Hub."},
    )
    dataset_config_name: str | None = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    train_val_split: float | None = field(
        default=0.15,
        metadata={"help": "Fraction to split off of train for validation (used only when no validation split exists)."},
    )
    max_train_samples: int | None = field(
        default=None,
        metadata={"help": "Truncate training set to this many samples (for debugging / quick tests)."},
    )
    max_eval_samples: int | None = field(
        default=None,
        metadata={"help": "Truncate evaluation set to this many samples."},
    )
    image_column_name: str = field(
        default="image",
        metadata={"help": "The column name for images in the dataset."},
    )
    label_column_name: str = field(
        default="label",
        metadata={"help": "The column name for labels in the dataset."},
    )


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default="timm/mobilenetv3_small_100.lamb_in1k",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."},
    )
    config_name: str | None = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name."},
    )
    cache_dir: str | None = field(
        default=None,
        metadata={"help": "Where to store pretrained models downloaded from the Hub."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (branch, tag, or commit id)."},
    )
    image_processor_name: str | None = field(
        default=None,
        metadata={"help": "Name or path of image processor config."},
    )
    ignore_mismatched_sizes: bool = field(
        default=True,
        metadata={"help": "Allow loading weights when num_labels differs from pretrained checkpoint."},
    )
    token: str | None = field(
        default=None,
        metadata={"help": "Auth token for private models / datasets."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust remote code from Hub repos."},
    )


def build_transforms(image_processor, is_training: bool):
    """Build torchvision transforms from the image processor's config."""
    if hasattr(image_processor, "size"):
        size = image_processor.size
        if "shortest_edge" in size:
            img_size = size["shortest_edge"]
        elif "height" in size and "width" in size:
            img_size = (size["height"], size["width"])
        else:
            img_size = 224
    else:
        img_size = 224

    if hasattr(image_processor, "image_mean") and image_processor.image_mean:
        normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    else:
        # ImageNet defaults
        normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if is_training:
        return Compose([
            RandomResizedCrop(img_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ])
    else:
        if isinstance(img_size, int):
            resize_size = int(img_size / 0.875)  # standard 87.5% center crop ratio
        else:
            resize_size = tuple(int(s / 0.875) for s in img_size)
        return Compose([
            Resize(resize_size),
            CenterCrop(img_size),
            ToTensor(),
            normalize,
        ])


def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # --- Hub authentication ---
    from huggingface_hub import login
    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("hfjob")
    if hf_token:
        login(token=hf_token)
        training_args.hub_token = hf_token
        logger.info("Logged in to Hugging Face Hub")
    elif training_args.push_to_hub:
        logger.warning("HF_TOKEN not found in environment. Hub push will likely fail.")

    # --- Trackio ---
    trackio.init(project=training_args.output_dir, name=training_args.run_name)

    # --- Logging ---
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, "
        f"n_gpu: {training_args.n_gpu}, distributed training: "
        f"{training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # --- Load dataset ---
    dataset = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        cache_dir=model_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )

    # --- Resolve label column ---
    label_col = data_args.label_column_name
    if label_col not in dataset["train"].column_names:
        candidates = [c for c in dataset["train"].column_names if c in ("label", "labels", "class", "fine_label")]
        if candidates:
            label_col = candidates[0]
            logger.info(f"Label column '{data_args.label_column_name}' not found, using '{label_col}'")
        else:
            raise ValueError(
                f"Label column '{data_args.label_column_name}' not found. "
                f"Available columns: {dataset['train'].column_names}"
            )

    # --- Discover labels ---
    label_feature = dataset["train"].features[label_col]
    if hasattr(label_feature, "names"):
        label_names = label_feature.names
    else:
        unique_labels = sorted(set(dataset["train"][label_col]))
        if all(isinstance(label, str) for label in unique_labels):
            label_names = unique_labels
        else:
            label_names = [str(label) for label in unique_labels]

    num_labels = len(label_names)
    id2label = dict(enumerate(label_names))
    label2id = {v: k for k, v in id2label.items()}
    logger.info(f"Number of classes: {num_labels}")

    # --- Remap string labels to int if needed ---
    sample_label = dataset["train"][0][label_col]
    if isinstance(sample_label, str):
        logger.info("Remapping string labels to integer IDs")
        for split_name in list(dataset.keys()):
            dataset[split_name] = dataset[split_name].map(
                lambda ex: {label_col: label2id[ex[label_col]]},
            )

    # --- Shuffle + train/val split ---
    dataset["train"] = dataset["train"].shuffle(seed=training_args.seed)

    data_args.train_val_split = None if "validation" in dataset else data_args.train_val_split
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(data_args.train_val_split, seed=training_args.seed)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # --- Truncate ---
    if data_args.max_train_samples is not None:
        max_train = min(data_args.max_train_samples, len(dataset["train"]))
        dataset["train"] = dataset["train"].select(range(max_train))
        logger.info(f"Truncated training set to {max_train} samples")
    if data_args.max_eval_samples is not None and "validation" in dataset:
        max_eval = min(data_args.max_eval_samples, len(dataset["validation"]))
        dataset["validation"] = dataset["validation"].select(range(max_eval))
        logger.info(f"Truncated validation set to {max_eval} samples")

    # --- Load model & image processor ---
    common_pretrained_args = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }

    config = AutoConfig.from_pretrained(
        model_args.config_name or model_args.model_name_or_path,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
        **common_pretrained_args,
    )

    model = AutoModelForImageClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
        **common_pretrained_args,
    )

    image_processor = AutoImageProcessor.from_pretrained(
        model_args.image_processor_name or model_args.model_name_or_path,
        **common_pretrained_args,
    )

    # --- Build transforms ---
    train_transforms = build_transforms(image_processor, is_training=True)
    val_transforms = build_transforms(image_processor, is_training=False)

    image_col = data_args.image_column_name

    def preprocess_train(examples):
        return {
            "pixel_values": [train_transforms(img.convert("RGB")) for img in examples[image_col]],
            "labels": examples[label_col],
        }

    def preprocess_val(examples):
        return {
            "pixel_values": [val_transforms(img.convert("RGB")) for img in examples[image_col]],
            "labels": examples[label_col],
        }

    dataset["train"].set_transform(preprocess_train)
    if "validation" in dataset:
        dataset["validation"].set_transform(preprocess_val)
    if "test" in dataset:
        dataset["test"].set_transform(preprocess_val)

    # --- Metrics ---
    accuracy_metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred: EvalPrediction):
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return accuracy_metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # --- Trainer ---
    eval_dataset = None
    if training_args.do_eval:
        if "validation" in dataset:
            eval_dataset = dataset["validation"]
        elif "test" in dataset:
            eval_dataset = dataset["test"]

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"] if training_args.do_train else None,
        eval_dataset=eval_dataset,
        processing_class=image_processor,
        data_collator=DefaultDataCollator(),
        compute_metrics=compute_metrics,
    )

    # --- Train ---
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # --- Evaluate ---
    if training_args.do_eval:
        test_dataset = dataset.get("test", dataset.get("validation"))
        test_prefix = "test" if "test" in dataset else "eval"
        if test_dataset is not None:
            metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix=test_prefix)
            trainer.log_metrics(test_prefix, metrics)
            trainer.save_metrics(test_prefix, metrics)

    trackio.finish()

    # --- Push to Hub ---
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": data_args.dataset_name,
        "tags": ["image-classification", "vision"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()
@@ -0,0 +1,710 @@
|
||||
# /// script
# dependencies = [
#     "transformers>=5.2.0",
#     "accelerate>=1.1.0",
#     "albumentations>=1.4.16",
#     "timm",
#     "datasets>=4.0",
#     "torchmetrics",
#     "pycocotools",
#     "trackio",
#     "huggingface_hub",
# ]
# ///

"""Finetuning any 🤗 Transformers model supported by AutoModelForObjectDetection for object detection leveraging the Trainer API."""

import logging
import math
import os
import sys
from collections.abc import Mapping
from dataclasses import dataclass, field
from functools import partial
from typing import Any

import albumentations as A
import numpy as np
import torch
from datasets import load_dataset
from torchmetrics.detection.mean_ap import MeanAveragePrecision

import trackio

import transformers
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForObjectDetection,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import center_to_corners_format
from transformers.trainer import EvalPrediction
from transformers.utils import check_min_version
from transformers.utils.versions import require_version


logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.57.0.dev0")

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")


@dataclass
class ModelOutput:
    logits: torch.Tensor
    pred_boxes: torch.Tensor


def format_image_annotations_as_coco(
    image_id: str, categories: list[int], areas: list[float], bboxes: list[tuple[float]]
) -> dict:
    """Format one set of image annotations to the COCO format

    Args:
        image_id (str): image id. e.g. "0001"
        categories (list[int]): list of categories/class labels corresponding to provided bounding boxes
        areas (list[float]): list of corresponding areas to provided bounding boxes
        bboxes (list[tuple[float]]): list of bounding boxes provided in COCO format
            ([x_min, y_min, width, height] in absolute coordinates)

    Returns:
        dict: {
            "image_id": image id,
            "annotations": list of formatted annotations
        }
    """
    annotations = []
    for category, area, bbox in zip(categories, areas, bboxes):
        formatted_annotation = {
            "image_id": image_id,
            "category_id": category,
            "iscrowd": 0,
            "area": area,
            "bbox": list(bbox),
        }
        annotations.append(formatted_annotation)

    return {
        "image_id": image_id,
        "annotations": annotations,
    }


def detect_bbox_format_from_samples(dataset, image_col="image", objects_col="objects", num_samples=50):
    """
    Detect whether bboxes are xyxy (Pascal VOC) or xywh (COCO) by checking
    bbox coordinates against image dimensions. The correct format interpretation
    should keep bboxes within image bounds.
    """
    exceeds_if_xywh = 0
    exceeds_if_xyxy = 0
    total = 0

    for example in dataset.select(range(min(num_samples, len(dataset)))):
        img_w, img_h = example[image_col].size
        for bbox in example[objects_col]["bbox"]:
            if len(bbox) != 4:
                continue
            a, b, c, d = float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])
            total += 1

            # If 3rd < 1st or 4th < 2nd, can't be xyxy (x_max must exceed x_min)
            if c < a or d < b:
                return "xywh"

            # xywh: right/bottom edge = origin + size; exceeding image → wrong format
            if a + c > img_w * 1.05:
                exceeds_if_xywh += 1
            if b + d > img_h * 1.05:
                exceeds_if_xywh += 1
            # xyxy: right/bottom edge = coordinate itself
            if c > img_w * 1.05:
                exceeds_if_xyxy += 1
            if d > img_h * 1.05:
                exceeds_if_xyxy += 1

    if total == 0:
        return "xywh"

    fmt = "xyxy" if exceeds_if_xywh > exceeds_if_xyxy else "xywh"
    logger.info(
        f"Detected bbox format: {fmt} (checked {total} bboxes from {min(num_samples, len(dataset))} images)"
    )
    return fmt


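# Illustrative sketch (not part of the script): the heuristic above reduced to a single
# box. `_looks_like_xyxy` is a hypothetical helper for demonstration only.
def _looks_like_xyxy(bbox, img_w, img_h):
    a, b, c, d = bbox
    if c < a or d < b:
        return False  # x_max/y_max smaller than x_min/y_min rules out xyxy
    # read as xywh the box would overflow the image, so xyxy is the more plausible reading
    return a + c > img_w * 1.05 or b + d > img_h * 1.05

# _looks_like_xyxy([60.0, 60.0, 90.0, 90.0], 100, 100) -> True

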
def sanitize_dataset(dataset, bbox_format="xywh", image_col="image", objects_col="objects"):
    """
    Validate bboxes, convert xyxy→xywh if needed, clip to image bounds, and remove
    entries with non-finite values, non-positive dimensions, or degenerate area (<1 px).
    Drops images with no remaining valid bboxes.
    """
    convert_xyxy = bbox_format == "xyxy"

    def _validate(example):
        img_w, img_h = example[image_col].size
        objects = example[objects_col]
        bboxes = objects["bbox"]
        n = len(bboxes)

        valid_indices = []
        converted_bboxes = []

        for i, bbox in enumerate(bboxes):
            if len(bbox) != 4:
                continue
            vals = [float(v) for v in bbox]
            if not all(math.isfinite(v) for v in vals):
                continue

            if convert_xyxy:
                x_min, y_min, x_max, y_max = vals
                w, h = x_max - x_min, y_max - y_min
            else:
                x_min, y_min, w, h = vals

            if w <= 0 or h <= 0:
                continue

            x_min, y_min = max(0.0, x_min), max(0.0, y_min)
            if x_min >= img_w or y_min >= img_h:
                continue
            w = min(w, img_w - x_min)
            h = min(h, img_h - y_min)

            if w * h < 1.0:
                continue

            valid_indices.append(i)
            converted_bboxes.append([x_min, y_min, w, h])

        # Rebuild objects dict, filtering all list-valued fields by valid_indices
        new_objects = {}
        for key, value in objects.items():
            if key == "bbox":
                new_objects["bbox"] = converted_bboxes
            elif isinstance(value, list) and len(value) == n:
                new_objects[key] = [value[j] for j in valid_indices]
            else:
                new_objects[key] = value

        if "area" not in new_objects or len(new_objects.get("area", [])) != len(converted_bboxes):
            new_objects["area"] = [b[2] * b[3] for b in converted_bboxes]

        example[objects_col] = new_objects
        return example

    before = len(dataset)
    dataset = dataset.map(_validate)
    dataset = dataset.filter(lambda ex: len(ex[objects_col]["bbox"]) > 0)
    after = len(dataset)
    if before != after:
        logger.warning(f"Dropped {before - after}/{before} images with no valid bboxes after sanitization")
    logger.info(f"Bbox sanitization complete: {after} images with valid bboxes remain")
    return dataset


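# Illustrative sketch (not part of the script): the per-box xyxy → xywh conversion and
# clipping that sanitize_dataset applies, as a standalone hypothetical helper.
def _xyxy_to_clipped_xywh(bbox, img_w, img_h):
    x_min, y_min, x_max, y_max = bbox
    w, h = x_max - x_min, y_max - y_min
    x_min, y_min = max(0.0, x_min), max(0.0, y_min)
    w = min(w, img_w - x_min)
    h = min(h, img_h - y_min)
    return [x_min, y_min, w, h]

# A box hanging off the right edge of a 100x100 image is clipped to it:
# _xyxy_to_clipped_xywh([80.0, 10.0, 120.0, 50.0], 100, 100) -> [80.0, 10.0, 20.0, 40.0]

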
def convert_bbox_yolo_to_pascal(boxes: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
    """
    Convert bounding boxes from YOLO format (x_center, y_center, width, height) in range [0, 1]
    to Pascal VOC format (x_min, y_min, x_max, y_max) in absolute coordinates.

    Args:
        boxes (torch.Tensor): Bounding boxes in YOLO format
        image_size (tuple[int, int]): Image size in format (height, width)

    Returns:
        torch.Tensor: Bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max)
    """
    # convert center to corners format
    boxes = center_to_corners_format(boxes)

    if isinstance(image_size, (torch.Tensor, np.ndarray)):
        image_size = image_size.tolist()
    height, width = image_size
    boxes = boxes * torch.tensor([[width, height, width, height]])

    return boxes


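# Illustrative sketch (not part of the script): the same mapping in plain Python for a
# single normalized YOLO box and an image of (height, width). `_yolo_to_pascal` is a
# hypothetical helper for demonstration only.
def _yolo_to_pascal(box, height, width):
    cx, cy, w, h = box
    return [
        (cx - w / 2) * width,
        (cy - h / 2) * height,
        (cx + w / 2) * width,
        (cy + h / 2) * height,
    ]

# _yolo_to_pascal([0.5, 0.5, 0.5, 0.5], 200, 100) -> [25.0, 50.0, 75.0, 150.0]

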
def augment_and_transform_batch(
    examples: Mapping[str, Any],
    transform: A.Compose,
    image_processor: AutoImageProcessor,
    return_pixel_mask: bool = False,
) -> BatchFeature:
    """Apply augmentations and format annotations in COCO format for the object detection task"""

    images = []
    annotations = []
    image_ids = examples["image_id"] if "image_id" in examples else range(len(examples["image"]))
    for image_id, image, objects in zip(image_ids, examples["image"], examples["objects"]):
        image = np.array(image.convert("RGB"))

        # Filter invalid bboxes before augmentation (safety net after sanitize_dataset)
        bboxes = objects["bbox"]
        categories = objects["category"]
        areas = objects["area"]
        valid = [
            (b, c, a)
            for b, c, a in zip(bboxes, categories, areas)
            if len(b) == 4 and b[2] > 0 and b[3] > 0 and b[0] >= 0 and b[1] >= 0
        ]
        if valid:
            bboxes, categories, areas = zip(*valid)
        else:
            bboxes, categories, areas = [], [], []

        # apply augmentations
        output = transform(image=image, bboxes=list(bboxes), category=list(categories))
        images.append(output["image"])

        # format annotations in COCO format (recompute areas from post-augmentation bboxes)
        post_areas = [b[2] * b[3] for b in output["bboxes"]] if output["bboxes"] else []
        formatted_annotations = format_image_annotations_as_coco(
            image_id, output["category"], post_areas, output["bboxes"]
        )
        annotations.append(formatted_annotations)

    # Apply the image processor transformations: resizing, rescaling, normalization
    result = image_processor(images=images, annotations=annotations, return_tensors="pt")

    if not return_pixel_mask:
        result.pop("pixel_mask", None)

    return result


def collate_fn(batch: list[BatchFeature]) -> Mapping[str, torch.Tensor | list[Any]]:
    data = {}
    data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
    data["labels"] = [x["labels"] for x in batch]
    if "pixel_mask" in batch[0]:
        data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
    return data


@torch.no_grad()
def compute_metrics(
    evaluation_results: EvalPrediction,
    image_processor: AutoImageProcessor,
    threshold: float = 0.0,
    id2label: Mapping[int, str] | None = None,
) -> Mapping[str, float]:
    """
    Compute mean average precision (mAP), mean average recall (mAR) and their variants for the object detection task.

    Args:
        evaluation_results (EvalPrediction): Predictions and targets from evaluation.
        threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
        id2label (Mapping[int, str], optional): Mapping from class id to class name. Defaults to None.

    Returns:
        Mapping[str, float]: Metrics in a form of dictionary {<metric_name>: <metric_value>}
    """

    predictions, targets = evaluation_results.predictions, evaluation_results.label_ids

    # For metric computation we need to provide:
    #  - targets in a form of list of dictionaries with keys "boxes", "labels"
    #  - predictions in a form of list of dictionaries with keys "boxes", "scores", "labels"

    image_sizes = []
    post_processed_targets = []
    post_processed_predictions = []

    # Collect targets in the required format for metric computation
    for batch in targets:
        # collect image sizes, we will need them for predictions post processing
        batch_image_sizes = torch.tensor([x["orig_size"] for x in batch])
        image_sizes.append(batch_image_sizes)
        # collect targets in the required format for metric computation
        # boxes were converted to YOLO format needed for model training
        # here we will convert them to Pascal VOC format (x_min, y_min, x_max, y_max)
        for image_target in batch:
            boxes = torch.tensor(image_target["boxes"])
            boxes = convert_bbox_yolo_to_pascal(boxes, image_target["orig_size"])
            labels = torch.tensor(image_target["class_labels"])
            post_processed_targets.append({"boxes": boxes, "labels": labels})

    # Collect predictions in the required format for metric computation:
    # the model produces boxes in YOLO format, then image_processor converts them to Pascal VOC format
    for batch, target_sizes in zip(predictions, image_sizes):
        batch_logits, batch_boxes = batch[1], batch[2]
        output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
        post_processed_output = image_processor.post_process_object_detection(
            output, threshold=threshold, target_sizes=target_sizes
        )
        post_processed_predictions.extend(post_processed_output)

    # Compute metrics
    metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
    metric.update(post_processed_predictions, post_processed_targets)
    metrics = metric.compute()

    # Replace list of per class metrics with separate metric for each class
    classes = metrics.pop("classes")
    map_per_class = metrics.pop("map_per_class")
    mar_100_per_class = metrics.pop("mar_100_per_class")
    # Single-class datasets return 0-d scalar tensors; make them iterable
    if classes.dim() == 0:
        classes = classes.unsqueeze(0)
        map_per_class = map_per_class.unsqueeze(0)
        mar_100_per_class = mar_100_per_class.unsqueeze(0)
    for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
        class_name = id2label[class_id.item()] if id2label is not None else class_id.item()
        metrics[f"map_{class_name}"] = class_map
        metrics[f"mar_100_{class_name}"] = class_mar

    metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

    return metrics


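# Illustrative sketch (not part of the script, hypothetical names and values): the
# per-class flattening above, one list-valued metric becoming one scalar entry per class.
def _flatten_per_class(metrics, id2label):
    classes = metrics.pop("classes")
    map_per_class = metrics.pop("map_per_class")
    for class_id, class_map in zip(classes, map_per_class):
        metrics[f"map_{id2label[class_id]}"] = class_map
    return metrics

# _flatten_per_class({"map": 0.5, "classes": [0, 1], "map_per_class": [0.4, 0.6]},
#                    {0: "helmet", 1: "vest"})
# -> {"map": 0.5, "map_helmet": 0.4, "map_vest": 0.6}

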
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
    them on the command line.
    """

    dataset_name: str = field(
        default="cppe-5",
        metadata={
            "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
        },
    )
    dataset_config_name: str | None = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_val_split: float | None = field(
        default=0.15, metadata={"help": "Percent to split off of train for validation."}
    )
    image_square_size: int | None = field(
        default=600,
        metadata={"help": "The longest image side will be resized to this value, then the image will be padded to a square."},
    )
    max_train_samples: int | None = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: int | None = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    use_fast: bool | None = field(
        default=True,
        metadata={"help": "Use a fast torchvision-based image processor if it is supported for a given model."},
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        default="facebook/detr-resnet-50",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    config_name: str | None = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    cache_dir: str | None = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    image_processor_name: str | None = field(default=None, metadata={"help": "Name or path of preprocessor config."})
    ignore_mismatched_sizes: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a checkpoint with 3 labels)."
        },
    )
    token: str | None = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `hf auth login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )


def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    from huggingface_hub import login

    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("hfjob")
    if hf_token:
        login(token=hf_token)
        training_args.hub_token = hf_token
        logger.info("Logged in to Hugging Face Hub")
    elif training_args.push_to_hub:
        logger.warning("HF_TOKEN not found in environment. Hub push will likely fail.")

    # Initialize Trackio for real-time experiment tracking
    trackio.init(project=training_args.output_dir, name=training_args.run_name)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    dataset = load_dataset(
        data_args.dataset_name, cache_dir=model_args.cache_dir, trust_remote_code=model_args.trust_remote_code
    )

    bbox_format = detect_bbox_format_from_samples(dataset["train"])
    if bbox_format == "xyxy":
        logger.info("Converting bboxes from xyxy (Pascal VOC) → xywh (COCO) format across all splits")
    for split_name in list(dataset.keys()):
        dataset[split_name] = sanitize_dataset(dataset[split_name], bbox_format=bbox_format)

    for split_name in list(dataset.keys()):
        if "image_id" not in dataset[split_name].column_names:
            dataset[split_name] = dataset[split_name].add_column(
                "image_id", list(range(len(dataset[split_name])))
            )

    dataset["train"] = dataset["train"].shuffle(seed=training_args.seed)

    data_args.train_val_split = None if "validation" in dataset else data_args.train_val_split
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(data_args.train_val_split, seed=training_args.seed)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    categories = None
    try:
        if isinstance(dataset["train"].features["objects"], dict):
            cat_feature = dataset["train"].features["objects"]["category"].feature
        else:
            cat_feature = dataset["train"].features["objects"].feature["category"]

        if hasattr(cat_feature, "names"):
            categories = cat_feature.names
    except (AttributeError, KeyError):
        pass

    if categories is None:
        # Category is a Value type (not ClassLabel) — scan dataset to discover labels
        logger.info("Category feature is not ClassLabel — scanning dataset to discover category labels...")
        unique_cats = set()
        for example in dataset["train"]:
            cats = example["objects"]["category"]
            if isinstance(cats, list):
                unique_cats.update(cats)
            else:
                unique_cats.add(cats)

        if all(isinstance(c, int) for c in unique_cats):
            max_cat = max(unique_cats)
            categories = [f"class_{i}" for i in range(max_cat + 1)]
        elif all(isinstance(c, str) for c in unique_cats):
            categories = sorted(unique_cats)
        else:
            categories = [str(c) for c in sorted(unique_cats, key=str)]
        logger.info(f"Discovered {len(categories)} categories: {categories}")

    id2label = dict(enumerate(categories))
    label2id = {v: k for k, v in id2label.items()}

    # Remap string categories to integer IDs if needed
    sample_cats = dataset["train"][0]["objects"]["category"]
    if sample_cats and isinstance(sample_cats[0], str):
        logger.info(f"Remapping string categories to integer IDs: {label2id}")

        def _remap_categories(example):
            objects = example["objects"]
            objects["category"] = [label2id[c] for c in objects["category"]]
            example["objects"] = objects
            return example

        for split_name in list(dataset.keys()):
            dataset[split_name] = dataset[split_name].map(_remap_categories)
        logger.info("Category remapping complete")

    if data_args.max_train_samples is not None:
        max_train = min(data_args.max_train_samples, len(dataset["train"]))
        dataset["train"] = dataset["train"].select(range(max_train))
        logger.info(f"Truncated training set to {max_train} samples")
    if data_args.max_eval_samples is not None and "validation" in dataset:
        max_eval = min(data_args.max_eval_samples, len(dataset["validation"]))
        dataset["validation"] = dataset["validation"].select(range(max_eval))
        logger.info(f"Truncated validation set to {max_eval} samples")

    common_pretrained_args = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    config = AutoConfig.from_pretrained(
        model_args.config_name or model_args.model_name_or_path,
        label2id=label2id,
        id2label=id2label,
        **common_pretrained_args,
    )
    model = AutoModelForObjectDetection.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
        **common_pretrained_args,
    )
    image_processor = AutoImageProcessor.from_pretrained(
        model_args.image_processor_name or model_args.model_name_or_path,
        do_resize=True,
        size={"max_height": data_args.image_square_size, "max_width": data_args.image_square_size},
        do_pad=True,
        pad_size={"height": data_args.image_square_size, "width": data_args.image_square_size},
        use_fast=data_args.use_fast,
        **common_pretrained_args,
    )

    max_size = data_args.image_square_size
    train_augment_and_transform = A.Compose(
        [
            A.Compose(
                [
                    A.SmallestMaxSize(max_size=max_size, p=1.0),
                    A.RandomSizedBBoxSafeCrop(height=max_size, width=max_size, p=1.0),
                ],
                p=0.2,
            ),
            A.OneOf(
                [
                    A.Blur(blur_limit=7, p=0.5),
                    A.MotionBlur(blur_limit=7, p=0.5),
                    A.Defocus(radius=(1, 5), alias_blur=(0.1, 0.25), p=0.1),
                ],
                p=0.1,
            ),
            A.Perspective(p=0.1),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.1),
        ],
        bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True, min_area=25),
    )
    validation_transform = A.Compose(
        [A.NoOp()],
        bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True),
    )

    train_transform_batch = partial(
        augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
    )
    validation_transform_batch = partial(
        augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
    )

    dataset["train"] = dataset["train"].with_transform(train_transform_batch)
    dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch)
    if "test" in dataset:
        dataset["test"] = dataset["test"].with_transform(validation_transform_batch)

    eval_compute_metrics_fn = partial(
        compute_metrics, image_processor=image_processor, id2label=id2label, threshold=0.0
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"] if training_args.do_train else None,
        eval_dataset=dataset["validation"] if training_args.do_eval else None,
        processing_class=image_processor,
        data_collator=collate_fn,
        compute_metrics=eval_compute_metrics_fn,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    if training_args.do_eval:
        test_dataset = dataset["test"] if "test" in dataset else dataset["validation"]
        test_prefix = "test" if "test" in dataset else "eval"
        metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix=test_prefix)
        trainer.log_metrics(test_prefix, metrics)
        trainer.save_metrics(test_prefix, metrics)

    trackio.finish()

    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": data_args.dataset_name,
        "tags": ["object-detection", "vision"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()
@@ -0,0 +1,382 @@
# /// script
# dependencies = [
#     "transformers>=5.2.0",
#     "accelerate>=1.1.0",
#     "datasets>=4.0",
#     "torchvision",
#     "monai",
#     "trackio",
#     "huggingface_hub",
# ]
# ///

"""Fine-tune SAM or SAM2 for segmentation using bounding-box or point prompts with the HF Trainer API."""

import json
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset

import monai
import trackio

import transformers
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.utils import check_min_version

logger = logging.getLogger(__name__)

check_min_version("4.57.0.dev0")


# ---------------------------------------------------------------------------
# Dataset wrapper
# ---------------------------------------------------------------------------

class SAMSegmentationDataset(Dataset):
    """Wraps a HF dataset into the format expected by SAM/SAM2 processors.

    Each sample must contain an image, a binary mask, and a prompt (bbox or
    point). Prompts are read from a JSON-encoded ``prompt`` column or from
    dedicated ``bbox`` / ``point`` columns.
    """

    def __init__(self, dataset, processor, prompt_type: str,
                 image_col: str, mask_col: str, prompt_col: str | None,
                 bbox_col: str | None, point_col: str | None):
        self.dataset = dataset
        self.processor = processor
        self.prompt_type = prompt_type
        self.image_col = image_col
        self.mask_col = mask_col
        self.prompt_col = prompt_col
        self.bbox_col = bbox_col
        self.point_col = point_col

    def __len__(self):
        return len(self.dataset)

    def _extract_prompt(self, item):
        if self.prompt_col and self.prompt_col in item:
            raw = item[self.prompt_col]
            parsed = json.loads(raw) if isinstance(raw, str) else raw
            if self.prompt_type == "bbox":
                return parsed.get("bbox") or parsed.get("box")
            return parsed.get("point") or parsed.get("points")

        if self.prompt_type == "bbox" and self.bbox_col:
            return item[self.bbox_col]
        if self.prompt_type == "point" and self.point_col:
            return item[self.point_col]
        raise ValueError("Could not extract prompt from sample")

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item[self.image_col]
        prompt = self._extract_prompt(item)

        if self.prompt_type == "bbox":
            inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        else:
            if isinstance(prompt[0], (int, float)):
                prompt = [prompt]
            inputs = self.processor(image, input_points=[[prompt]], return_tensors="pt")

        mask = np.array(item[self.mask_col])
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        inputs["labels"] = (mask > 0).astype(np.float32)
        inputs["original_image_size"] = torch.tensor(image.size[::-1])
        return inputs


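# Illustrative sketch (not part of the script): the JSON prompt parsing used by
# _extract_prompt above, reduced to a standalone hypothetical helper for a bbox prompt.
import json


def _parse_bbox_prompt(raw):
    parsed = json.loads(raw) if isinstance(raw, str) else raw
    return parsed.get("bbox") or parsed.get("box")

# _parse_bbox_prompt('{"bbox": [10, 20, 110, 220]}') -> [10, 20, 110, 220]

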
def collate_fn(batch):
    pixel_values = torch.cat([item["pixel_values"] for item in batch], dim=0)
    original_sizes = torch.stack([item["original_sizes"] for item in batch])
    original_image_size = torch.stack([item["original_image_size"] for item in batch])

    has_boxes = "input_boxes" in batch[0]
    has_points = "input_points" in batch[0]

    labels = torch.cat(
        [
            F.interpolate(
                torch.as_tensor(x["labels"]).unsqueeze(0).unsqueeze(0).float(),
                size=(256, 256),
                mode="nearest",
            )
            for x in batch
        ],
        dim=0,
    ).long()

    result = {
        "pixel_values": pixel_values,
        "original_sizes": original_sizes,
        "labels": labels,
        "original_image_size": original_image_size,
        "multimask_output": False,
    }

    if has_boxes:
        result["input_boxes"] = torch.cat([item["input_boxes"] for item in batch], dim=0)
    if has_points:
        result["input_points"] = torch.cat([item["input_points"] for item in batch], dim=0)
    if "input_labels" in batch[0]:
        result["input_labels"] = torch.cat([item["input_labels"] for item in batch], dim=0)

    return result


# ---------------------------------------------------------------------------
# Custom loss (SAM/SAM2 don't compute loss in forward())
# ---------------------------------------------------------------------------

seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction="mean")


def compute_loss(outputs, labels, num_items_in_batch=None):
    predicted_masks = outputs.pred_masks.squeeze(1)
    return seg_loss(predicted_masks, labels.float())


# ---------------------------------------------------------------------------
|
||||
# CLI arguments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class DataTrainingArguments:
    dataset_name: str = field(
        default="merve/MicroMat-mini",
        metadata={"help": "Hub dataset ID."},
    )
    dataset_config_name: str | None = field(
        default=None,
        metadata={"help": "Dataset config name."},
    )
    train_val_split: float | None = field(
        default=0.1,
        metadata={"help": "Fraction to split off for validation (used when no validation split exists)."},
    )
    max_train_samples: int | None = field(
        default=None,
        metadata={"help": "Truncate training set (for quick tests)."},
    )
    max_eval_samples: int | None = field(
        default=None,
        metadata={"help": "Truncate evaluation set."},
    )
    image_column_name: str = field(
        default="image",
        metadata={"help": "Column containing PIL images."},
    )
    mask_column_name: str = field(
        default="mask",
        metadata={"help": "Column containing ground-truth binary masks."},
    )
    prompt_column_name: str | None = field(
        default="prompt",
        metadata={"help": "Column with JSON-encoded prompt (bbox/point). Set to '' to disable."},
    )
    bbox_column_name: str | None = field(
        default=None,
        metadata={"help": "Column with bbox prompt ([x0,y0,x1,y1]). Used when prompt_column_name is unset."},
    )
    point_column_name: str | None = field(
        default=None,
        metadata={"help": "Column with point prompt ([x,y] or [[x,y],...]). Used when prompt_column_name is unset."},
    )
    prompt_type: str = field(
        default="bbox",
        metadata={"help": "Prompt type: 'bbox' or 'point'."},
    )


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default="facebook/sam2.1-hiera-small",
        metadata={"help": "Pretrained SAM/SAM2 model identifier."},
    )
    cache_dir: str | None = field(default=None, metadata={"help": "Cache directory."})
    model_revision: str = field(default="main", metadata={"help": "Model revision."})
    token: str | None = field(default=None, metadata={"help": "Auth token."})
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code."})
    freeze_vision_encoder: bool = field(
        default=True,
        metadata={"help": "Freeze vision encoder weights."},
    )
    freeze_prompt_encoder: bool = field(
        default=True,
        metadata={"help": "Freeze prompt encoder weights."},
    )


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    parser.set_defaults(per_device_train_batch_size=4, num_train_epochs=30)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    from huggingface_hub import login

    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("hfjob")
    if hf_token:
        login(token=hf_token)
        training_args.hub_token = hf_token
        logger.info("Logged in to Hugging Face Hub")
    elif training_args.push_to_hub:
        logger.warning("HF_TOKEN not found in environment. Hub push will likely fail.")

    trackio.init(project=training_args.output_dir, name=training_args.run_name)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.info(f"Training/evaluation parameters {training_args}")

    # ---- Load dataset ----
    dataset = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        cache_dir=model_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )

    if "train" not in dataset:
        if len(dataset.keys()) == 1:
            only_split = list(dataset.keys())[0]
            dataset[only_split] = dataset[only_split].shuffle(seed=training_args.seed)
            dataset = dataset[only_split].train_test_split(
                test_size=data_args.train_val_split or 0.1, seed=training_args.seed
            )
            dataset = {"train": dataset["train"], "validation": dataset["test"]}
        else:
            raise ValueError(f"No 'train' split found. Available: {list(dataset.keys())}")
    elif "validation" not in dataset and "test" not in dataset:
        dataset["train"] = dataset["train"].shuffle(seed=training_args.seed)
        split = dataset["train"].train_test_split(
            test_size=data_args.train_val_split or 0.1, seed=training_args.seed
        )
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    if data_args.max_train_samples is not None:
        n = min(data_args.max_train_samples, len(dataset["train"]))
        dataset["train"] = dataset["train"].select(range(n))
        logger.info(f"Truncated training set to {n} samples")
    eval_key = "validation" if "validation" in dataset else "test"
    if data_args.max_eval_samples is not None and eval_key in dataset:
        n = min(data_args.max_eval_samples, len(dataset[eval_key]))
        dataset[eval_key] = dataset[eval_key].select(range(n))
        logger.info(f"Truncated eval set to {n} samples")

    # ---- Detect model family (SAM vs SAM2) and load processor/model ----
    model_id = model_args.model_name_or_path.lower()
    is_sam2 = "sam2" in model_id

    if is_sam2:
        from transformers import Sam2Processor, Sam2Model

        processor = Sam2Processor.from_pretrained(model_args.model_name_or_path)
        model = Sam2Model.from_pretrained(model_args.model_name_or_path)
    else:
        from transformers import SamProcessor, SamModel

        processor = SamProcessor.from_pretrained(model_args.model_name_or_path)
        model = SamModel.from_pretrained(model_args.model_name_or_path)

    if model_args.freeze_vision_encoder:
        for name, param in model.named_parameters():
            if name.startswith("vision_encoder"):
                param.requires_grad_(False)
    if model_args.freeze_prompt_encoder:
        for name, param in model.named_parameters():
            if name.startswith("prompt_encoder"):
                param.requires_grad_(False)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")

    # ---- Build datasets ----
    prompt_col = data_args.prompt_column_name if data_args.prompt_column_name else None
    ds_kwargs = dict(
        processor=processor,
        prompt_type=data_args.prompt_type,
        image_col=data_args.image_column_name,
        mask_col=data_args.mask_column_name,
        prompt_col=prompt_col,
        bbox_col=data_args.bbox_column_name,
        point_col=data_args.point_column_name,
    )

    train_dataset = SAMSegmentationDataset(dataset=dataset["train"], **ds_kwargs)
    eval_dataset = None
    if eval_key in dataset:
        eval_dataset = SAMSegmentationDataset(dataset=dataset[eval_key], **ds_kwargs)

    # ---- Train ----
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        data_collator=collate_fn,
        compute_loss_func=compute_loss,
    )

    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    if training_args.do_eval and eval_dataset is not None:
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    trackio.finish()

    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": data_args.dataset_name,
        "tags": ["image-segmentation", "vision", "sam"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()
273
plugins/antigravity-awesome-skills-claude/skills/jq/SKILL.md
Normal file
@@ -0,0 +1,273 @@

---
name: jq
description: "Expert jq usage for JSON querying, filtering, transformation, and pipeline integration. Practical patterns for real shell workflows."
category: development
risk: safe
source: community
date_added: "2026-03-28"
author: kostakost2
tags: [jq, json, shell, cli, data-transformation, bash]
tools: [claude, cursor, gemini]
---

# jq — JSON Querying and Transformation

## Overview

`jq` is the standard CLI tool for querying and reshaping JSON. This skill covers practical, expert-level usage: filtering deeply nested data, transforming structures, aggregating values, and composing `jq` into shell pipelines. Every example is copy-paste ready for real workflows.

## When to Use This Skill

- Use when parsing JSON output from APIs, CLI tools (AWS, GitHub, kubectl, docker), or log files
- Use when transforming JSON structure (rename keys, flatten arrays, group records)
- Use when the user needs `jq` inside a bash script or one-liner
- Use when explaining what a complex `jq` expression does

## How It Works

`jq` takes a filter expression and applies it to JSON input. Filters compose with pipes (`|`), and `jq` handles arrays, objects, strings, numbers, booleans, and `null` natively.

### Basic Selection

```bash
# Extract a field
echo '{"name":"alice","age":30}' | jq '.name'
# "alice"

# Nested access
echo '{"user":{"email":"a@b.com"}}' | jq '.user.email'

# Array index
echo '[10, 20, 30]' | jq '.[1]'
# 20

# Array slice
echo '[1,2,3,4,5]' | jq '.[2:4]'
# [3, 4]

# All array elements
echo '[{"id":1},{"id":2}]' | jq '.[]'
```

### Filtering with `select`

```bash
# Keep only matching elements
echo '[{"role":"admin"},{"role":"user"},{"role":"admin"}]' \
  | jq '[.[] | select(.role == "admin")]'

# Numeric comparison
curl -s https://api.github.com/repos/owner/repo/issues \
  | jq '[.[] | select(.comments > 5)]'

# Test that a field exists and is non-null
jq '[.[] | select(.email != null)]'

# Combine conditions
jq '[.[] | select(.active == true and .score >= 80)]'
```

### Mapping and Transformation

```bash
# Extract a field from every array element
echo '[{"name":"alice","age":30},{"name":"bob","age":25}]' \
  | jq '[.[] | .name]'
# ["alice", "bob"]

# Shorthand: map()
jq 'map(.name)'

# Build a new object per element
jq '[.[] | {user: .name, years: .age}]'

# Add a computed field
jq '[.[] | . + {senior: (.age > 28)}]'

# Rename keys
jq '[.[] | {username: .name, email_address: .email}]'
```

### Aggregation and Reduce

```bash
# Sum all values
echo '[1, 2, 3, 4, 5]' | jq 'add'
# 15

# Sum a field across objects
jq '[.[].price] | add'

# Count elements
jq 'length'

# Max / min
jq 'max_by(.score)'
jq 'min_by(.created_at)'

# reduce: custom accumulator
echo '[1,2,3,4,5]' | jq 'reduce .[] as $x (0; . + $x)'
# 15

# Group by field
jq 'group_by(.department)'

# Count per group
jq 'group_by(.status) | map({status: .[0].status, count: length})'
```

### String Interpolation and Formatting

```bash
# String interpolation
jq -r '.[] | "\(.name) is \(.age) years old"'

# Format as CSV (no header)
jq -r '.[] | [.name, .age, .email] | @csv'

# Format as TSV
jq -r '.[] | [.name, .score] | @tsv'

# URL-encode a value
jq -r '.query | @uri'

# Base64 encode
jq -r '.data | @base64'
```

### Working with Keys and Paths

```bash
# List all top-level keys
jq 'keys'

# Check if a key exists
jq 'has("email")'

# Delete a key
jq 'del(.password)'

# Delete nested keys from every element
jq '[.[] | del(.internal_id, .raw_payload)]'

# Recursive descent: find all values for a key anywhere in the tree
jq '.. | .id? // empty'

# Get all leaf paths
jq '[paths(scalars)]'
```

### Conditionals and Error Handling

```bash
# if-then-else
jq 'if .score >= 90 then "A" elif .score >= 80 then "B" else "C" end'

# Alternative operator: use a fallback if null or false
jq '.nickname // .name'

# try-catch: skip errors instead of halting
jq '[.[] | try .nested.value catch null]'

# Suppress null output with // empty
jq '.[] | .optional_field // empty'
```

### Practical Shell Integration

```bash
# Read from a file
jq '.users' data.json

# Compact output (no whitespace) for further piping
jq -c '.[]' records.json | while IFS= read -r record; do
  echo "Processing: $record"
done

# Pass a shell variable into jq
STATUS="active"
jq --arg s "$STATUS" '[.[] | select(.status == $s)]'

# Pass a number
jq --argjson threshold 42 '[.[] | select(.value > $threshold)]'

# Slurp multiple JSON lines into an array
jq -s '.' records.ndjson

# Multiple files: slurp all into one array
jq -s 'add' file1.json file2.json

# Pipe JSON from another command
kubectl get pods -o json | jq '.items[] | {name: .metadata.name, status: .status.phase}'

# GitHub CLI: extract PR numbers
gh pr list --json number,title | jq -r '.[] | "\(.number)\t\(.title)"'

# AWS CLI: list running instance IDs
aws ec2 describe-instances \
  | jq -r '.Reservations[].Instances[] | select(.State.Name=="running") | .InstanceId'

# Docker: show container names and images
docker inspect $(docker ps -q) | jq -r '.[] | "\(.Name)\t\(.Config.Image)"'
```

### Advanced Patterns

```bash
# Transpose an object of arrays into an array of objects
# Input: {"names":["a","b"],"scores":[10,20]}
jq '[.names, .scores] | transpose | map({name: .[0], score: .[1]})'

# Flatten one level
jq 'flatten(1)'

# Unique by field
jq 'unique_by(.email)'

# Deduplicate (unique also sorts its output)
jq '[.[] | .name] | unique'

# Walk: apply a transformation to every node recursively
jq 'walk(if type == "string" then ascii_downcase else . end)'

# env: read environment variables inside jq
export API_KEY=secret
jq -n 'env.API_KEY'
```

## Best Practices

- Always use `-r` (raw output) when passing `jq` results to shell variables or other commands to strip JSON string quotes
- Use `--arg` / `--argjson` to inject shell variables safely — never interpolate shell variables directly into filter strings
- Prefer `map(f)` over `[.[] | f]` for readability
- Use `-c` (compact) for newline-delimited JSON pipelines; omit it for human-readable debugging
- Test filters interactively with `jq -n` and literal input before embedding them in scripts
- Use `empty` to drop unwanted elements rather than filtering to `null`
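
A compact sketch of the `--arg` and `-r` practices above, using inline sample data (the variable names and records are illustrative, and `jq` is assumed to be on `PATH`):

```shell
# Inline sample data; in practice this would come from a file or an API.
json='[{"name":"alice","status":"active"},{"name":"bob","status":"idle"}]'

# Inject the shell variable with --arg (never string-interpolate it into the
# filter), and use -r so the result is an unquoted string usable in the shell.
status="active"
match=$(printf '%s' "$json" | jq -r --arg s "$status" '.[] | select(.status == $s) | .name')
echo "$match"   # prints: alice
```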

## Security & Safety Notes

- `jq` is read-only by design — it cannot write files or execute commands
- Avoid embedding untrusted JSON field values directly into shell commands; always quote or use `--arg`

## Common Pitfalls

- **Problem:** `jq` outputs `null` instead of the expected value
  **Solution:** Check for typos in key names; use `keys` to inspect actual field names. Remember JSON is case-sensitive.

- **Problem:** Numbers are quoted as strings in the output
  **Solution:** Use `--argjson` instead of `--arg` when injecting numeric values.

- **Problem:** Filter works in the terminal but fails in a script
  **Solution:** Ensure the filter string uses single quotes in the shell to prevent variable expansion. Example: `jq '.field'` not `jq ".field"`.

- **Problem:** `add` returns `null` on an empty array
  **Solution:** Use `add // 0` or `add // ""` to provide a fallback default.

- **Problem:** Streaming large files is slow
  **Solution:** Use `jq --stream` or switch to `jstream`/`gron` for very large files.
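
The empty-array fallback mentioned in the pitfalls above can be verified directly with inline data (assumes `jq` is installed):

```shell
# add on an empty array yields null; // 0 substitutes a usable default.
echo '[]' | jq 'add'            # prints: null
echo '[]' | jq 'add // 0'       # prints: 0
echo '[1,2,3]' | jq 'add // 0'  # prints: 6
```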

## Related Skills

- `@bash-pro` — Wrapping jq calls in robust shell scripts
- `@bash-linux` — General shell pipeline patterns
- `@github-automation` — Using jq with GitHub CLI JSON output
370
plugins/antigravity-awesome-skills-claude/skills/tmux/SKILL.md
Normal file
@@ -0,0 +1,370 @@

---
name: tmux
description: "Expert tmux session, window, and pane management for terminal multiplexing, persistent remote workflows, and shell scripting automation."
category: development
risk: safe
source: community
date_added: "2026-03-28"
author: kostakost2
tags: [tmux, terminal, multiplexer, sessions, shell, remote, automation]
tools: [claude, cursor, gemini]
---

# tmux — Terminal Multiplexer

## Overview

`tmux` keeps terminal sessions alive across SSH disconnects, splits work across multiple panes, and enables fully scriptable terminal automation. This skill covers session management, window/pane layout, keybinding patterns, and using `tmux` non-interactively from shell scripts — essential for remote servers, long-running jobs, and automated workflows.

## When to Use This Skill

- Use when setting up or managing persistent terminal sessions on remote servers
- Use when the user needs to run long-running processes that survive SSH disconnects
- Use when scripting multi-pane terminal layouts (e.g., logs + shell + editor)
- Use when automating `tmux` commands from bash scripts without user interaction

## How It Works

`tmux` has three hierarchy levels: **sessions** (top level, survive disconnects), **windows** (tabs within a session), and **panes** (splits within a window). Everything is controllable from outside via `tmux <command>` or from inside via the prefix key (`Ctrl-b` by default).

### Session Management

```bash
# Create a new named session
tmux new-session -s work

# Create a detached (background) session
tmux new-session -d -s work

# Create a detached session and start a command
tmux new-session -d -s build -x 220 -y 50 "make all"

# Attach to a session
tmux attach -t work
tmux attach          # attaches to the most recent session

# List all sessions
tmux list-sessions
tmux ls

# Detach from inside tmux
# Prefix + d (Ctrl-b d)

# Kill a session
tmux kill-session -t work

# Kill all sessions except the current one
tmux kill-session -a

# Rename a session from outside
tmux rename-session -t old-name new-name

# Switch to another session from outside
tmux switch-client -t other-session

# Check if a session exists (useful in scripts)
tmux has-session -t work 2>/dev/null && echo "exists"
```

### Window Management

```bash
# Create a new window in the current session
tmux new-window -t work -n "logs"

# Create a window running a specific command
tmux new-window -t work:3 -n "server" "python -m http.server 8080"

# List windows
tmux list-windows -t work

# Select (switch to) a window
tmux select-window -t work:logs
tmux select-window -t work:2     # by index

# Rename a window
tmux rename-window -t work:2 "editor"

# Kill a window
tmux kill-window -t work:logs

# Move a window to a new index
tmux move-window -s work:3 -t work:1

# From inside tmux:
# Prefix + c   — new window
# Prefix + ,   — rename window
# Prefix + &   — kill window
# Prefix + n/p — next/previous window
# Prefix + 0-9 — switch to window by number
```

### Pane Management

```bash
# Split pane vertically (left/right)
tmux split-window -h -t work:1

# Split pane horizontally (top/bottom)
tmux split-window -v -t work:1

# Split and run a command
tmux split-window -h -t work:1 "tail -f /var/log/syslog"

# Select a pane by index
tmux select-pane -t work:1.0

# Resize panes
tmux resize-pane -t work:1.0 -R 20   # expand right by 20 columns
tmux resize-pane -t work:1.0 -D 10   # resize downward by 10 rows
tmux resize-pane -Z                  # toggle zoom (fullscreen)

# Swap panes
tmux swap-pane -s work:1.0 -t work:1.1

# Kill a pane
tmux kill-pane -t work:1.1

# From inside tmux:
# Prefix + %     — split vertical
# Prefix + "     — split horizontal
# Prefix + arrow — navigate panes
# Prefix + z     — zoom/unzoom current pane
# Prefix + x     — kill pane
# Prefix + {/}   — swap pane with previous/next
```
### Sending Commands to Panes Without Being Attached

```bash
# Send a command to a specific pane and press Enter
tmux send-keys -t work:1.0 "ls -la" Enter

# Run a command in a background pane without attaching
tmux send-keys -t work:editor "vim src/main.py" Enter

# Send Ctrl+C to stop a running process
tmux send-keys -t work:1.0 C-c

# Send text without pressing Enter (useful for pre-filling prompts)
tmux send-keys -t work:1.0 "git commit -m '"

# Clear a pane
tmux send-keys -t work:1.0 "clear" Enter

# Check what's in a pane (capture its output)
tmux capture-pane -t work:1.0 -p
tmux capture-pane -t work:1.0 -p | grep "ERROR"
```

### Scripting a Full Workspace Layout

This is the most powerful pattern: create a fully configured multi-pane workspace from a single script.

```bash
#!/usr/bin/env bash
set -euo pipefail

SESSION="dev"

# Bail if the session already exists
tmux has-session -t "$SESSION" 2>/dev/null && {
    echo "Session $SESSION already exists. Attaching..."
    tmux attach -t "$SESSION"
    exit 0
}

# Create the session with its first window
tmux new-session -d -s "$SESSION" -n "editor" -x 220 -y 50

# Window 1: editor + test runner side by side
tmux send-keys -t "$SESSION:editor" "vim ." Enter
tmux split-window -h -t "$SESSION:editor"
tmux send-keys -t "$SESSION:editor.1" "npm test -- --watch" Enter
tmux select-pane -t "$SESSION:editor.0"

# Window 2: server logs
tmux new-window -t "$SESSION" -n "server"
tmux send-keys -t "$SESSION:server" "docker compose up" Enter
tmux split-window -v -t "$SESSION:server"
tmux send-keys -t "$SESSION:server.1" "tail -f logs/app.log" Enter

# Window 3: general shell
tmux new-window -t "$SESSION" -n "shell"

# Focus the first window
tmux select-window -t "$SESSION:editor"

# Attach
tmux attach -t "$SESSION"
```

### Configuration (`~/.tmux.conf`)

```bash
# Change prefix to Ctrl-a (screen-style)
unbind C-b
set -g prefix C-a
bind C-a send-prefix

# Enable mouse support
set -g mouse on

# Start window/pane numbering at 1
set -g base-index 1
setw -g pane-base-index 1

# Renumber windows when one is closed
set -g renumber-windows on

# Increase scrollback buffer
set -g history-limit 50000

# Use vi keys in copy mode
setw -g mode-keys vi

# Faster key repetition
set -s escape-time 0

# Reload config without restarting
bind r source-file ~/.tmux.conf \; display "Config reloaded"

# Intuitive splits: | and -
bind | split-window -h -c "#{pane_current_path}"
bind - split-window -v -c "#{pane_current_path}"

# New windows open in the current directory
bind c new-window -c "#{pane_current_path}"

# Status bar
set -g status-right "#{session_name} | %H:%M %d-%b"
set -g status-interval 5
```
### Copy Mode and Scrollback

```bash
# Enter copy mode (scroll up through output)
# Prefix + [

# In vi mode:
# / to search forward, ? to search backward
# Space to start selection, Enter to copy
# q to exit copy mode

# Paste the most recent buffer
# Prefix + ]

# List paste buffers
tmux list-buffers

# Show the most recent buffer
tmux show-buffer

# Save buffer to a file
tmux save-buffer /tmp/tmux-output.txt

# Load a file into a buffer
tmux load-buffer /tmp/data.txt

# Pipe pane output to a command
tmux pipe-pane -t work:1.0 "cat >> ~/session.log"
```

### Practical Automation Patterns

```bash
# Idempotent session: create or attach
ensure_session() {
    local name="$1"
    tmux has-session -t "$name" 2>/dev/null \
        || tmux new-session -d -s "$name"
    tmux attach -t "$name"
}

# Run a command in a new background window
run_bg() {
    local session="${1:-main}" cmd="${*:2}"
    tmux new-window -t "$session" -n "bg-$$"
    tmux send-keys -t "$session:bg-$$" "$cmd" Enter
}

# Wait for a pane to produce specific output (polling)
wait_for_output() {
    local target="$1" pattern="$2" timeout="${3:-30}"
    local elapsed=0
    while (( elapsed < timeout )); do
        tmux capture-pane -t "$target" -p | grep -q "$pattern" && return 0
        sleep 1
        (( elapsed++ ))
    done
    return 1
}

# Kill all background windows matching a name prefix
kill_bg_windows() {
    local session="$1" prefix="${2:-bg-}"
    tmux list-windows -t "$session" -F "#W" \
        | grep "^${prefix}" \
        | while read -r win; do
            tmux kill-window -t "${session}:${win}"
        done
}
```

### Remote and SSH Workflows

```bash
# SSH and immediately attach to an existing session
ssh user@host -t "tmux attach -t work || tmux new-session -s work"

# Run a command on a remote host inside a tmux session (fire and forget)
ssh user@host "tmux new-session -d -s deploy 'bash /opt/deploy.sh'"

# Watch the remote session output from another terminal
ssh user@host -t "tmux attach -t deploy -r"   # read-only attach

# Pair programming: share a session (both users attach to the same session)
# User 1:
tmux new-session -s shared
# User 2 (same server):
tmux attach -t shared
```

## Best Practices

- Always name sessions (`-s name`) in scripts — unnamed sessions are hard to target reliably
- Use `tmux has-session -t name 2>/dev/null` before creating to make scripts idempotent
- Set `-x` and `-y` when creating detached sessions to give panes a proper size for commands that check terminal dimensions
- Use `send-keys ... Enter` for automation rather than piping stdin — it works even when the target pane is running an interactive program
- Keep `~/.tmux.conf` in version control for reproducibility across machines
- Prefer `bind -n` for bindings that don't need the prefix, but only for keys that don't conflict with application shortcuts

## Security & Safety Notes

- `send-keys` executes commands in a pane without confirmation — verify the target (`-t session:window.pane`) before use in scripts to avoid sending keystrokes to the wrong pane
- Read-only attach (`-r`) is appropriate when sharing sessions with others to prevent accidental input
- Avoid storing secrets in tmux window/pane titles or environment variables exported into sessions on shared machines

## Common Pitfalls

- **Problem:** `tmux` commands from a script fail with "no server running"
  **Solution:** Start the server first with `tmux start-server`, or create a detached session before running other commands.

- **Problem:** Pane size is 0x0 when creating a detached session
  **Solution:** Pass explicit dimensions: `tmux new-session -d -s name -x 200 -y 50`.

- **Problem:** `send-keys` types the text but doesn't run the command
  **Solution:** Ensure you pass `Enter` (capital E) as a separate final argument: `tmux send-keys -t target "cmd" Enter`.

- **Problem:** Script creates a duplicate session each run
  **Solution:** Guard with `tmux has-session -t name 2>/dev/null || tmux new-session -d -s name`.

- **Problem:** Copy-mode selection doesn't work as expected
  **Solution:** Confirm `mode-keys vi` or `mode-keys emacs` is set to match your preference in `~/.tmux.conf`.

## Related Skills

- `@bash-pro` — Writing the shell scripts that orchestrate tmux sessions
- `@bash-linux` — General Linux terminal patterns used inside tmux panes
- `@ssh` — Combining tmux with SSH for persistent remote workflows
---
source: "https://github.com/huggingface/skills/tree/main/skills/transformers-js"
name: transformers-js
description: Run Hugging Face models in JavaScript or TypeScript with Transformers.js in Node.js or the browser.
license: Apache-2.0
risk: unknown
metadata:
  author: huggingface
  version: "3.8.1"
  category: machine-learning
  repository: https://github.com/huggingface/transformers.js
  compatibility: Requires Node.js 18+ or modern browser with ES modules support. WebGPU support requires compatible browser/environment. Internet access needed for downloading models from Hugging Face Hub (optional if using local models).
---

# Transformers.js - Machine Learning for JavaScript

Transformers.js enables running state-of-the-art machine learning models directly in JavaScript, both in browsers and Node.js environments, with no server required.

## When to Use This Skill

Use this skill when you need to:
- Run ML models for text analysis, generation, or translation in JavaScript
- Perform image classification, object detection, or segmentation
- Implement speech recognition or audio processing
- Build multimodal AI applications (text-to-image, image-to-text, etc.)
- Run models client-side in the browser without a backend

## Installation

### NPM Installation
```bash
npm install @huggingface/transformers
```

### Browser Usage (CDN)
```html
<script type="module">
  import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers';
</script>
```

## Core Concepts

### 1. Pipeline API
The pipeline API is the easiest way to use models. It groups together preprocessing, model inference, and postprocessing:

```javascript
import { pipeline } from '@huggingface/transformers';

// Create a pipeline for a specific task
const pipe = await pipeline('sentiment-analysis');

// Use the pipeline
const result = await pipe('I love transformers!');
// Output: [{ label: 'POSITIVE', score: 0.999817686 }]

// IMPORTANT: Always dispose when done to free memory
await pipe.dispose();
```

**⚠️ Memory Management:** All pipelines must be disposed with `pipe.dispose()` when finished to prevent memory leaks. See examples in [Code Examples](./references/EXAMPLES.md) for cleanup patterns across different environments.

### 2. Model Selection
You can specify a custom model as the second argument:

```javascript
const pipe = await pipeline(
  'sentiment-analysis',
  'Xenova/bert-base-multilingual-uncased-sentiment'
);
```

**Finding Models:**

Browse available Transformers.js models on Hugging Face Hub:
- **All models**: https://huggingface.co/models?library=transformers.js&sort=trending
- **By task**: Add the `pipeline_tag` parameter
  - Text generation: https://huggingface.co/models?pipeline_tag=text-generation&library=transformers.js&sort=trending
  - Image classification: https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js&sort=trending
  - Speech recognition: https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&library=transformers.js&sort=trending

**Tip:** Filter by task type, sort by trending/downloads, and check model cards for performance metrics and usage examples.

### 3. Device Selection
Choose where to run the model:

```javascript
// Run on CPU (default for WASM)
const cpuPipe = await pipeline('sentiment-analysis', 'model-id');

// Run on GPU (WebGPU - experimental)
const gpuPipe = await pipeline('sentiment-analysis', 'model-id', {
  device: 'webgpu',
});
```

### 4. Quantization Options
Control model precision vs. performance:

```javascript
// Use quantized model (faster, smaller)
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  dtype: 'q4', // Options: 'fp32', 'fp16', 'q8', 'q4'
});
```

## Supported Tasks

**Note:** All examples below show basic usage.

### Natural Language Processing

#### Text Classification
```javascript
const classifier = await pipeline('text-classification');
const result = await classifier('This movie was amazing!');
```

#### Named Entity Recognition (NER)
```javascript
const ner = await pipeline('token-classification');
const entities = await ner('My name is John and I live in New York.');
```

#### Question Answering
```javascript
const qa = await pipeline('question-answering');
const answer = await qa({
  question: 'What is the capital of France?',
  context: 'Paris is the capital and largest city of France.'
});
```

#### Text Generation
```javascript
const generator = await pipeline('text-generation', 'onnx-community/gemma-3-270m-it-ONNX');
const text = await generator('Once upon a time', {
  max_new_tokens: 100,
  temperature: 0.7
});
```

**For streaming and chat:** See **[Text Generation Guide](./references/TEXT_GENERATION.md)** for:
- Streaming token-by-token output with `TextStreamer`
- Chat/conversation format with system/user/assistant roles
- Generation parameters (temperature, top_k, top_p)
- Browser and Node.js examples
- React components and API endpoints

#### Translation
```javascript
const translator = await pipeline('translation', 'Xenova/nllb-200-distilled-600M');
const output = await translator('Hello, how are you?', {
  src_lang: 'eng_Latn',
  tgt_lang: 'fra_Latn'
});
```

#### Summarization
```javascript
const summarizer = await pipeline('summarization');
const summary = await summarizer(longText, {
  max_length: 100,
  min_length: 30
});
```

#### Zero-Shot Classification
```javascript
const classifier = await pipeline('zero-shot-classification');
const result = await classifier('This is a story about sports.', ['politics', 'sports', 'technology']);
```

### Computer Vision

#### Image Classification
```javascript
const classifier = await pipeline('image-classification');
const result = await classifier('https://example.com/image.jpg');
// Also accepts a local file path (Node.js) or an object URL/Blob (browser)
```

#### Object Detection
```javascript
const detector = await pipeline('object-detection');
const objects = await detector('https://example.com/image.jpg');
// Returns: [{ label: 'person', score: 0.95, box: { xmin, ymin, xmax, ymax } }, ...]
```

#### Image Segmentation
```javascript
const segmenter = await pipeline('image-segmentation');
const segments = await segmenter('https://example.com/image.jpg');
```

#### Depth Estimation
```javascript
const depthEstimator = await pipeline('depth-estimation');
const depth = await depthEstimator('https://example.com/image.jpg');
```

#### Zero-Shot Image Classification
```javascript
const classifier = await pipeline('zero-shot-image-classification');
const result = await classifier('image.jpg', ['cat', 'dog', 'bird']);
```

### Audio Processing

#### Automatic Speech Recognition
```javascript
const transcriber = await pipeline('automatic-speech-recognition');
const result = await transcriber('audio.wav');
// Returns: { text: 'transcribed text here' }
```

#### Audio Classification
```javascript
const classifier = await pipeline('audio-classification');
const result = await classifier('audio.wav');
```

#### Text-to-Speech
```javascript
const synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts');
const audio = await synthesizer('Hello, this is a test.', {
  speaker_embeddings: speakerEmbeddings
});
```

### Multimodal

#### Image-to-Text (Image Captioning)
```javascript
const captioner = await pipeline('image-to-text');
const caption = await captioner('image.jpg');
```

#### Document Question Answering
```javascript
const docQA = await pipeline('document-question-answering');
const answer = await docQA('document-image.jpg', 'What is the total amount?');
```

#### Zero-Shot Object Detection
```javascript
const detector = await pipeline('zero-shot-object-detection');
const objects = await detector('image.jpg', ['person', 'car', 'tree']);
```

### Feature Extraction (Embeddings)

```javascript
const extractor = await pipeline('feature-extraction');
const embeddings = await extractor('This is a sentence to embed.');
// Returns: tensor of shape [1, sequence_length, hidden_size]

// For sentence embeddings (mean pooling)
const sentenceExtractor = await pipeline('feature-extraction', 'onnx-community/all-MiniLM-L6-v2-ONNX');
const sentenceEmbeddings = await sentenceExtractor('Text to embed', { pooling: 'mean', normalize: true });
```
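Pooled, normalized sentence embeddings are typically compared with cosine similarity. As a minimal sketch over plain number arrays (assuming the tensor data has been converted with something like `Array.from(sentenceEmbeddings.data)`):

```javascript
// Cosine similarity between two embedding vectors held as plain arrays.
// With { normalize: true } above, the dot product alone already equals the
// cosine score; the full formula shown here also handles unnormalized vectors.
function cosineSimilarity(a, b) {
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}
```

Scores close to 1 indicate semantically similar sentences; values near 0 indicate unrelated text.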

## Finding and Choosing Models

### Browsing the Hugging Face Hub

Discover compatible Transformers.js models on Hugging Face Hub:

**Base URL (all models):**
```
https://huggingface.co/models?library=transformers.js&sort=trending
```

**Filter by task** using the `pipeline_tag` parameter:

| Task | URL |
|------|-----|
| **Text Generation** | https://huggingface.co/models?pipeline_tag=text-generation&library=transformers.js&sort=trending |
| **Text Classification** | https://huggingface.co/models?pipeline_tag=text-classification&library=transformers.js&sort=trending |
| **Translation** | https://huggingface.co/models?pipeline_tag=translation&library=transformers.js&sort=trending |
| **Summarization** | https://huggingface.co/models?pipeline_tag=summarization&library=transformers.js&sort=trending |
| **Question Answering** | https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js&sort=trending |
| **Image Classification** | https://huggingface.co/models?pipeline_tag=image-classification&library=transformers.js&sort=trending |
| **Object Detection** | https://huggingface.co/models?pipeline_tag=object-detection&library=transformers.js&sort=trending |
| **Image Segmentation** | https://huggingface.co/models?pipeline_tag=image-segmentation&library=transformers.js&sort=trending |
| **Speech Recognition** | https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&library=transformers.js&sort=trending |
| **Audio Classification** | https://huggingface.co/models?pipeline_tag=audio-classification&library=transformers.js&sort=trending |
| **Image-to-Text** | https://huggingface.co/models?pipeline_tag=image-to-text&library=transformers.js&sort=trending |
| **Feature Extraction** | https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js&sort=trending |
| **Zero-Shot Classification** | https://huggingface.co/models?pipeline_tag=zero-shot-classification&library=transformers.js&sort=trending |

**Sort options:**
- `&sort=trending` - Most popular recently
- `&sort=downloads` - Most downloaded overall
- `&sort=likes` - Most liked by community
- `&sort=modified` - Recently updated

### Choosing the Right Model

Consider these factors when selecting a model:

**1. Model Size**
- **Small (< 100MB)**: Fast, suitable for browsers, limited accuracy
- **Medium (100MB - 500MB)**: Balanced performance, good for most use cases
- **Large (> 500MB)**: High accuracy, slower, better for Node.js or powerful devices

**2. Quantization**
Models are often available in different quantization levels:
- `fp32` - Full precision (largest, most accurate)
- `fp16` - Half precision (smaller, still accurate)
- `q8` - 8-bit quantized (much smaller, slight accuracy loss)
- `q4` - 4-bit quantized (smallest, noticeable accuracy loss)

**3. Task Compatibility**
Check the model card for:
- Supported tasks (some models support multiple tasks)
- Input/output formats
- Language support (multilingual vs. English-only)
- License restrictions

**4. Performance Metrics**
Model cards typically show:
- Accuracy scores
- Benchmark results
- Inference speed
- Memory requirements

### Example: Finding a Text Generation Model

```javascript
// 1. Visit: https://huggingface.co/models?pipeline_tag=text-generation&library=transformers.js&sort=trending

// 2. Browse and select a model (e.g., onnx-community/gemma-3-270m-it-ONNX)

// 3. Check model card for:
//    - Model size: ~270M parameters
//    - Quantization: q4 available
//    - Language: English
//    - Use case: Instruction-following chat

// 4. Use the model:
import { pipeline } from '@huggingface/transformers';

const generator = await pipeline(
  'text-generation',
  'onnx-community/gemma-3-270m-it-ONNX',
  { dtype: 'q4' } // Use quantized version for faster inference
);

const output = await generator('Explain quantum computing in simple terms.', {
  max_new_tokens: 100
});

await generator.dispose();
```

### Tips for Model Selection

1. **Start Small**: Test with a smaller model first, then upgrade if needed
2. **Check ONNX Support**: Ensure the model has ONNX files (look for an `onnx` folder in the model repo)
3. **Read Model Cards**: Model cards contain usage examples, limitations, and benchmarks
4. **Test Locally**: Benchmark inference speed and memory usage in your environment
5. **Community Models**: Look for models by `Xenova` (Transformers.js maintainer) or `onnx-community`
6. **Version Pin**: Use specific git commits in production for stability:
   ```javascript
   const pipe = await pipeline('task', 'model-id', { revision: 'abc123' });
   ```

## Advanced Configuration

### Environment Configuration (`env`)

The `env` object provides comprehensive control over Transformers.js execution, caching, and model loading.

**Quick Overview:**

```javascript
import { env } from '@huggingface/transformers';

// View version
console.log(env.version); // e.g., '3.8.1'

// Common settings
env.allowRemoteModels = true;    // Load from Hugging Face Hub
env.allowLocalModels = false;    // Disable loading from the local file system
env.localModelPath = '/models/'; // Local model directory
env.useFSCache = true;           // Cache models on disk (Node.js)
env.useBrowserCache = true;      // Cache models in browser
env.cacheDir = './.cache';       // Cache directory location
```

**Configuration Patterns:**

```javascript
// Development: Fast iteration with remote models
env.allowRemoteModels = true;
env.useFSCache = true;

// Production: Local models only
env.allowRemoteModels = false;
env.allowLocalModels = true;
env.localModelPath = '/app/models/';

// Custom CDN
env.remoteHost = 'https://cdn.example.com/models';

// Disable caching (testing)
env.useFSCache = false;
env.useBrowserCache = false;
```

For complete documentation on all configuration options, caching strategies, cache management, pre-downloading models, and more, see:

**→ [Configuration Reference](./references/CONFIGURATION.md)**

### Working with Tensors

```javascript
import { AutoTokenizer, AutoModel } from '@huggingface/transformers';

// Load tokenizer and model separately for more control
const tokenizer = await AutoTokenizer.from_pretrained('bert-base-uncased');
const model = await AutoModel.from_pretrained('bert-base-uncased');

// Tokenize input
const inputs = await tokenizer('Hello world!');

// Run model
const outputs = await model(inputs);
```

### Batch Processing

```javascript
const classifier = await pipeline('sentiment-analysis');

// Process multiple texts
const results = await classifier([
  'I love this!',
  'This is terrible.',
  'It was okay.'
]);
```

## Browser-Specific Considerations

### WebGPU Usage
WebGPU provides GPU acceleration in browsers:

```javascript
const pipe = await pipeline('text-generation', 'onnx-community/gemma-3-270m-it-ONNX', {
  device: 'webgpu',
  dtype: 'fp32'
});
```

**Note**: WebGPU is experimental. Check browser compatibility and file issues if problems occur.

### WASM Performance
Default browser execution uses WASM:

```javascript
// Optimized for browsers with quantization
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  dtype: 'q8' // or 'q4' for even smaller size
});
```

### Progress Tracking & Loading Indicators

Models can be large (ranging from a few MB to several GB) and consist of multiple files. Track download progress by passing a callback to the `pipeline()` function:

```javascript
import { pipeline } from '@huggingface/transformers';

// Track progress for each file
const fileProgress = {};

function onProgress(info) {
  console.log(`${info.status}: ${info.file}`);

  if (info.status === 'progress') {
    fileProgress[info.file] = info.progress;
    console.log(`${info.file}: ${info.progress.toFixed(1)}%`);
  }

  if (info.status === 'done') {
    console.log(`✓ ${info.file} complete`);
  }
}

// Pass callback to pipeline (null = use the default model for the task)
const classifier = await pipeline('sentiment-analysis', null, {
  progress_callback: onProgress
});
```

**Progress Info Properties:**

```typescript
interface ProgressInfo {
  status: 'initiate' | 'download' | 'progress' | 'done' | 'ready';
  name: string;      // Model id or path
  file: string;      // File being processed
  progress?: number; // Percentage (0-100, only for 'progress' status)
  loaded?: number;   // Bytes downloaded (only for 'progress' status)
  total?: number;    // Total bytes (only for 'progress' status)
}
```
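To drive a single loading bar from the per-file percentages collected above, one hedged approach is a small aggregation helper (`overallProgress` is a hypothetical name, not part of the Transformers.js API):

```javascript
// Hypothetical helper: reduce the per-file percentages gathered by the
// progress callback (the `fileProgress` object above) to one overall number.
// Averaging by file is an approximation; weighting by `total` bytes is more
// accurate when file sizes differ widely.
function overallProgress(fileProgress) {
  const values = Object.values(fileProgress);
  if (values.length === 0) return 0;
  return values.reduce((sum, pct) => sum + pct, 0) / values.length;
}
```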

For complete examples including browser UIs, React components, CLI progress bars, and retry logic, see:

**→ [Pipeline Options - Progress Callback](./references/PIPELINE_OPTIONS.md#progress-callback)**

## Error Handling

```javascript
try {
  const pipe = await pipeline('sentiment-analysis', 'model-id');
  const result = await pipe('text to analyze');
} catch (error) {
  if (error.message.includes('fetch')) {
    console.error('Model download failed. Check internet connection.');
  } else if (error.message.includes('ONNX')) {
    console.error('Model execution failed. Check model compatibility.');
  } else {
    console.error('Unknown error:', error);
  }
}
```

## Performance Tips

1. **Reuse Pipelines**: Create a pipeline once, reuse it for multiple inferences
2. **Use Quantization**: Start with `q8` or `q4` for faster inference
3. **Batch Processing**: Process multiple inputs together when possible
4. **Cache Models**: Models are cached automatically (see **[Caching Reference](./references/CACHE.md)** for details on browser Cache API, Node.js filesystem cache, and custom implementations)
5. **WebGPU for Large Models**: Use WebGPU for models that benefit from GPU acceleration
6. **Prune Context**: For text generation, limit `max_new_tokens` to avoid memory issues
7. **Clean Up Resources**: Call `pipe.dispose()` when done to free memory
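Tip 1 (reuse pipelines) is commonly implemented as a lazy singleton. A minimal sketch, where `factory` stands in for a call like `() => pipeline('sentiment-analysis')`:

```javascript
// Lazy singleton: the factory runs once and its promise is shared, so even
// concurrent callers never trigger a duplicate model load.
class PipelineSingleton {
  static promise = null;

  static get(factory) {
    this.promise ??= factory();
    return this.promise;
  }
}
```

The first call creates the pipeline; every later call returns the same in-flight or resolved promise.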

## Memory Management

**IMPORTANT:** Always call `pipe.dispose()` when finished to prevent memory leaks.

```javascript
const pipe = await pipeline('sentiment-analysis');
const result = await pipe('Great product!');
await pipe.dispose(); // ✓ Free memory (100MB - several GB per model)
```

**When to dispose:**
- Application shutdown or component unmount
- Before loading a different model
- After batch processing in long-running apps

Models consume significant memory and hold GPU/CPU resources. Disposal is critical for browser memory limits and server stability.

For detailed patterns (React cleanup, servers, browser), see **[Code Examples](./references/EXAMPLES.md)**.

## Troubleshooting

### Model Not Found
- Verify the model exists on Hugging Face Hub
- Check model name spelling
- Ensure the model has ONNX files (look for an `onnx` folder in the model repo)

### Memory Issues
- Use smaller models or quantized versions (`dtype: 'q4'`)
- Reduce batch size
- Limit sequence length with `max_length`

### WebGPU Errors
- Check browser compatibility (Chrome 113+, Edge 113+)
- Try `dtype: 'fp16'` if `fp32` fails
- Fall back to WASM if WebGPU is unavailable
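The WASM fallback can be selected up front with a simple feature check; a sketch (the `typeof` guard keeps it safe in Node.js, where `navigator` may be absent):

```javascript
// Choose WebGPU when the environment exposes it, otherwise fall back to WASM.
const device =
  typeof navigator !== 'undefined' && 'gpu' in navigator ? 'webgpu' : 'wasm';

// The chosen device is then passed as a pipeline option, e.g.:
// const pipe = await pipeline('sentiment-analysis', 'model-id', { device });
```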

## Reference Documentation

### This Skill
- **[Pipeline Options](./references/PIPELINE_OPTIONS.md)** - Configure `pipeline()` with `progress_callback`, `device`, `dtype`, etc.
- **[Configuration Reference](./references/CONFIGURATION.md)** - Global `env` configuration for caching and model loading
- **[Caching Reference](./references/CACHE.md)** - Browser Cache API, Node.js filesystem cache, and custom cache implementations
- **[Text Generation Guide](./references/TEXT_GENERATION.md)** - Streaming, chat format, and generation parameters
- **[Model Architectures](./references/MODEL_ARCHITECTURES.md)** - Supported models and selection tips
- **[Code Examples](./references/EXAMPLES.md)** - Real-world implementations for different runtimes

### Official Transformers.js
- Official docs: https://huggingface.co/docs/transformers.js
- API reference: https://huggingface.co/docs/transformers.js/api/pipelines
- Model hub: https://huggingface.co/models?library=transformers.js
- GitHub: https://github.com/huggingface/transformers.js
- Examples: https://github.com/huggingface/transformers.js/tree/main/examples

## Best Practices

1. **Always Dispose Pipelines**: Call `pipe.dispose()` when done - critical for preventing memory leaks
2. **Start with Pipelines**: Use the pipeline API unless you need fine-grained control
3. **Test Locally First**: Test models with small inputs before deploying
4. **Monitor Model Sizes**: Be aware of model download sizes for web applications
5. **Handle Loading States**: Show progress indicators for better UX
6. **Version Pin**: Pin specific model versions for production stability
7. **Error Boundaries**: Always wrap pipeline calls in try-catch blocks
8. **Progressive Enhancement**: Provide fallbacks for unsupported browsers
9. **Reuse Models**: Load once, use many times - don't recreate pipelines unnecessarily
10. **Graceful Shutdown**: Dispose models on SIGTERM/SIGINT in servers
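The graceful-shutdown practice can be sketched as a small helper. The `registerShutdown` name and the injectable `proc` parameter are illustrative choices (the parameter exists only so the handler can be exercised without killing the process), not part of any library API:

```javascript
// Dispose the pipeline on SIGINT/SIGTERM before exiting, so model memory is
// released even when the server is stopped by a signal.
function registerShutdown(pipe, proc = process) {
  const cleanup = async () => {
    await pipe.dispose();
    proc.exit(0);
  };
  proc.once('SIGINT', cleanup);
  proc.once('SIGTERM', cleanup);
  return cleanup; // returned so it can also be invoked manually
}
```

In a server, call `registerShutdown(pipe)` once after the pipeline is created.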

## Quick Reference: Task IDs

| Task | Task ID |
|------|---------|
| Text classification | `text-classification` or `sentiment-analysis` |
| Token classification | `token-classification` or `ner` |
| Question answering | `question-answering` |
| Fill mask | `fill-mask` |
| Summarization | `summarization` |
| Translation | `translation` |
| Text generation | `text-generation` |
| Text-to-text generation | `text2text-generation` |
| Zero-shot classification | `zero-shot-classification` |
| Image classification | `image-classification` |
| Image segmentation | `image-segmentation` |
| Object detection | `object-detection` |
| Depth estimation | `depth-estimation` |
| Image-to-image | `image-to-image` |
| Zero-shot image classification | `zero-shot-image-classification` |
| Zero-shot object detection | `zero-shot-object-detection` |
| Automatic speech recognition | `automatic-speech-recognition` |
| Audio classification | `audio-classification` |
| Text-to-speech | `text-to-speech` or `text-to-audio` |
| Image-to-text | `image-to-text` |
| Document question answering | `document-question-answering` |
| Feature extraction | `feature-extraction` |
| Sentence similarity | `sentence-similarity` |

---

This skill enables you to integrate state-of-the-art machine learning capabilities directly into JavaScript applications without requiring separate ML servers or Python environments.
# Caching Reference
|
||||
|
||||
Complete guide to caching strategies for Transformers.js models across different environments.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Overview](#overview)
|
||||
2. [Browser Caching](#browser-caching)
|
||||
3. [Node.js Caching](#nodejs-caching)
|
||||
4. [Custom Cache Implementation](#custom-cache-implementation)
|
||||
5. [Cache Configuration](#cache-configuration)
|
||||
|
||||
## Overview
|
||||
|
||||
Transformers.js models can be large (from a few MB to several GB), so caching is critical for performance. The caching strategy differs based on the environment:
|
||||
|
||||
- **Browser**: Uses the Cache API (browser cache storage)
|
||||
- **Node.js**: Uses filesystem cache in `~/.cache/huggingface/`
|
||||
- **Custom**: Implement your own cache (database, cloud storage, etc.)
|
||||
|
||||
### Default Behavior
|
||||
|
||||
```javascript
|
||||
import { pipeline } from '@huggingface/transformers';
|
||||
|
||||
// First load: downloads model
|
||||
const pipe = await pipeline('sentiment-analysis');
|
||||
|
||||
// Subsequent loads: uses cached model
|
||||
const pipe2 = await pipeline('sentiment-analysis'); // Fast!
|
||||
```
|
||||
|
||||
Caching is **automatic** and enabled by default. Models are cached after the first download.
|
||||
|
||||
## Browser Caching
|
||||
|
||||
### Using the Cache API
|
||||
|
||||
In browser environments, Transformers.js uses the [Cache API](https://developer.mozilla.org/en-US/docs/Web/API/Cache) to store models:
|
||||
|
||||
```javascript
|
||||
import { env, pipeline } from '@huggingface/transformers';
|
||||
|
||||
// Browser cache is enabled by default
|
||||
console.log(env.useBrowserCache); // true
|
||||
|
||||
// Load model (cached automatically)
|
||||
const classifier = await pipeline('sentiment-analysis');
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
1. Model files are downloaded from Hugging Face Hub
|
||||
2. Files are stored in the browser's Cache Storage
|
||||
3. Subsequent loads retrieve from cache (no network request)
|
||||
4. Cache persists across page reloads and browser sessions
|
||||
|
||||
### Cache Location
|
||||
|
||||
Browser caches are stored in:
|
||||
- **Chrome/Edge**: `Cache Storage` in DevTools → Application tab → Cache storage
|
||||
- **Firefox**: `about:cache` → Storage
|
||||
- **Safari**: Web Inspector → Storage tab
|
||||
|
||||
### Disable Browser Cache
|
||||
|
||||
```javascript
|
||||
import { env } from '@huggingface/transformers';
|
||||
|
||||
// Disable browser caching (not recommended)
|
||||
env.useBrowserCache = false;
|
||||
|
||||
// Models will be re-downloaded on every page load
|
||||
```
|
||||
|
||||
**Use case:** Testing, development, or debugging cache issues.
|
||||
|
||||
### Browser Storage Limits
|
||||
|
||||
Browsers impose storage quotas:
|
||||
|
||||
- **Chrome**: ~60% of available disk space (but can evict data)
|
||||
- **Firefox**: ~50% of available disk space
|
||||
- **Safari**: ~1GB per origin (prompt for more)
|
||||
|
||||
**Tip:** Monitor storage usage with the [Storage API](https://developer.mozilla.org/en-US/docs/Web/API/Storage_API):
|
||||
|
||||
```javascript
|
||||
if ('storage' in navigator && 'estimate' in navigator.storage) {
|
||||
const estimate = await navigator.storage.estimate();
|
||||
const percentUsed = (estimate.usage / estimate.quota) * 100;
|
||||
console.log(`Storage: ${percentUsed.toFixed(2)}% used`);
|
||||
console.log(`Available: ${((estimate.quota - estimate.usage) / 1024 / 1024).toFixed(2)} MB`);
|
||||
}
|
||||
```
|
||||
|
||||
## Node.js Caching
|
||||
|
||||
### Filesystem Cache
|
||||
|
||||
In Node.js, models are cached to the filesystem:
|
||||
|
||||
```javascript
|
||||
import { env, pipeline } from '@huggingface/transformers';
|
||||
|
||||
// Default cache directory (Node.js)
|
||||
console.log(env.cacheDir); // './.cache' (relative to current directory)
|
||||
|
||||
// Filesystem cache is enabled by default
|
||||
console.log(env.useFSCache); // true
|
||||
|
||||
// Load model (cached to disk)
|
||||
const classifier = await pipeline('sentiment-analysis');
|
||||
```
|
||||
|
||||
### Default Cache Location
|
||||
|
||||
**Default behavior:**
|
||||
- Cache directory: `./.cache` (relative to where Node.js process runs)
|
||||
- Full default path: `~/.cache/huggingface/` when using Hugging Face tools
|
||||
|
||||
**Note:** The statement "Models are cached automatically in `~/.cache/huggingface/`" from performance tips is specific to Hugging Face's Python tooling convention. In Transformers.js for Node.js, the default is `./.cache` unless configured otherwise.
|
||||
|
||||
### Custom Cache Directory

```javascript
import os from 'os';
import path from 'path';
import { env, pipeline } from '@huggingface/transformers';

// Option 1: derive the cache directory from an environment variable
// (Node.js does not expand '~', so build the home-directory path explicitly)
env.cacheDir = process.env.HF_HOME || path.join(os.homedir(), '.cache', 'huggingface');

// Option 2: set a fixed custom cache directory
env.cacheDir = '/var/cache/transformers';

// Now load model
const classifier = await pipeline('sentiment-analysis');
// Cached to: /var/cache/transformers/models--Xenova--distilbert-base-uncased-finetuned-sst-2-english/
```

**Pattern:** `models--{organization}--{model-name}/`

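As an illustrative sketch (not part of the Transformers.js API), this naming pattern can be used to enumerate which models a cache directory currently holds:

```javascript
import fs from 'fs/promises';

// Turn a cache-directory entry like "models--Xenova--my-model" back into
// the Hub id "Xenova/my-model". Returns null for non-model entries.
function cacheEntryToModelId(entry) {
  const match = entry.match(/^models--(.+?)--(.+)$/);
  return match ? `${match[1]}/${match[2]}` : null;
}

// List the model ids cached under a given directory (default './.cache').
async function listCachedModels(cacheDir = './.cache') {
  let entries = [];
  try {
    entries = await fs.readdir(cacheDir);
  } catch {
    // Cache directory does not exist yet
  }
  return entries.map(cacheEntryToModelId).filter(Boolean);
}

console.log(await listCachedModels());
```

The helper names (`cacheEntryToModelId`, `listCachedModels`) are hypothetical; only the `models--{organization}--{model-name}` directory pattern comes from the library's behavior.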
### Disable Filesystem Cache

```javascript
import { env } from '@huggingface/transformers';

// Disable filesystem caching (not recommended)
env.useFSCache = false;

// Models will be re-downloaded on every load
```

**Use case:** Testing, CI/CD environments, or containers with ephemeral storage.

## Custom Cache Implementation

Implement your own cache for specialized storage backends.

### Custom Cache Interface

```typescript
interface CacheInterface {
  /**
   * Check if a URL is cached
   */
  match(url: string): Promise<Response | undefined>;

  /**
   * Store a URL and its response
   */
  put(url: string, response: Response): Promise<void>;
}
```

### Example: Cloud Storage Cache (S3)

```javascript
import { env, pipeline } from '@huggingface/transformers';
import { S3Client, GetObjectCommand, PutObjectCommand } from '@aws-sdk/client-s3';

class S3Cache {
  constructor(bucket, region = 'us-east-1') {
    this.bucket = bucket;
    this.s3 = new S3Client({ region });
  }

  async match(url) {
    const key = this.urlToKey(url);

    try {
      const command = new GetObjectCommand({
        Bucket: this.bucket,
        Key: key
      });
      const response = await this.s3.send(command);

      // Convert stream to buffer
      const chunks = [];
      for await (const chunk of response.Body) {
        chunks.push(chunk);
      }
      const body = Buffer.concat(chunks);

      return new Response(body, {
        status: 200,
        headers: JSON.parse(response.Metadata.headers || '{}')
      });
    } catch (error) {
      if (error.name === 'NoSuchKey') return undefined;
      throw error;
    }
  }

  async put(url, response) {
    const key = this.urlToKey(url);
    const clonedResponse = response.clone();
    const body = Buffer.from(await clonedResponse.arrayBuffer());
    const headers = JSON.stringify(Object.fromEntries(response.headers.entries()));

    const command = new PutObjectCommand({
      Bucket: this.bucket,
      Key: key,
      Body: body,
      Metadata: { headers }
    });

    await this.s3.send(command);
  }

  urlToKey(url) {
    // Convert URL to S3 key (remove protocol, replace slashes)
    return url.replace(/^https?:\/\//, '').replace(/\//g, '_');
  }
}

// Configure S3 cache
env.useCustomCache = true;
env.customCache = new S3Cache('my-transformers-cache', 'us-east-1');
env.useFSCache = false;

// Use S3 cache
const classifier = await pipeline('sentiment-analysis');
```

## Cache Configuration

### Environment Variables

Use environment variables to configure caching:

```javascript
import { env } from '@huggingface/transformers';

// Configure cache directory from environment
env.cacheDir = process.env.TRANSFORMERS_CACHE || './.cache';

// Disable caching in CI/CD
if (process.env.CI === 'true') {
  env.useFSCache = false;
  env.useBrowserCache = false;
}

// Production: use pre-cached models
if (process.env.NODE_ENV === 'production') {
  env.allowRemoteModels = false;
  env.allowLocalModels = true;
  env.localModelPath = process.env.MODEL_PATH || '/app/models';
}
```

### Configuration Patterns

#### Development: Enable All Caching

```javascript
import { env } from '@huggingface/transformers';

env.allowRemoteModels = true;
env.useFSCache = true;      // Node.js
env.useBrowserCache = true; // Browser
env.cacheDir = './.cache';
```

#### Production: Local Models Only

```javascript
import { env } from '@huggingface/transformers';

env.allowRemoteModels = false;
env.allowLocalModels = true;
env.localModelPath = '/app/models';
env.useFSCache = true;
```

#### Testing: Disable Caching

```javascript
import { env } from '@huggingface/transformers';

env.useFSCache = false;
env.useBrowserCache = false;
env.allowRemoteModels = true; // Download every time
```

#### Hybrid: Cache + Remote Fallback

```javascript
import { env } from '@huggingface/transformers';

// Try local cache first, fall back to remote
env.allowRemoteModels = true;
env.allowLocalModels = true;
env.useFSCache = true;
env.localModelPath = './models';
```

---

## Summary

Transformers.js provides flexible caching options:

- **Browser**: Cache API (automatic, persistent)
- **Node.js**: Filesystem cache (default `./.cache`, configurable)
- **Custom**: Implement your own (database, cloud storage, etc.)

**Key takeaways:**

1. Caching is enabled by default and automatic
2. Configure the cache **before** loading models
3. Browser uses the Cache API, Node.js uses the filesystem
4. Custom caches enable advanced storage backends
5. Monitor cache size and implement cleanup strategies
6. Pre-download models for production deployments

For more configuration options, see:
- [Configuration Reference](./CONFIGURATION.md)
- [Pipeline Options](./PIPELINE_OPTIONS.md)

---

# Environment Configuration Reference

Complete guide to configuring Transformers.js behavior using the `env` object.

## Table of Contents

1. [Overview](#overview)
2. [Remote Model Configuration](#remote-model-configuration)
3. [Local Model Configuration](#local-model-configuration)
4. [Cache Configuration](#cache-configuration)
5. [WASM Configuration](#wasm-configuration)
6. [Common Configuration Patterns](#common-configuration-patterns)
7. [Environment Best Practices](#environment-best-practices)

## Overview

The `env` object provides comprehensive control over Transformers.js execution, caching, and model loading:

```javascript
import { env } from '@huggingface/transformers';

// View current version
console.log(env.version); // e.g., '3.8.1'
```

### Available Properties

```typescript
interface TransformersEnvironment {
  // Version info
  version: string;

  // Backend configuration
  backends: {
    onnx: Partial<ONNXEnv>;
  };

  // Remote model settings
  allowRemoteModels: boolean;
  remoteHost: string;
  remotePathTemplate: string;

  // Local model settings
  allowLocalModels: boolean;
  localModelPath: string;
  useFS: boolean;

  // Cache settings
  useBrowserCache: boolean;
  useFSCache: boolean;
  cacheDir: string | null;
  useCustomCache: boolean;
  customCache: CacheInterface | null;
  useWasmCache: boolean;
  cacheKey: string;
}
```

## Remote Model Configuration

Control how models are loaded from remote sources (default: Hugging Face Hub).

### Disable Remote Loading

```javascript
import { env } from '@huggingface/transformers';

// Force local-only mode (no network requests)
env.allowRemoteModels = false;
```

**Use case:** Offline applications, security requirements, or air-gapped environments.

### Custom Model Host

```javascript
import { env } from '@huggingface/transformers';

// Use your own CDN or model server
env.remoteHost = 'https://cdn.example.com/models';

// Customize the URL pattern
// Default: '{model}/resolve/{revision}/{file}'
env.remotePathTemplate = 'custom/{model}/{file}';
```

**Use case:** Self-hosting models, using a CDN for faster downloads, or corporate proxies.

### Example: Private Model Server

```javascript
import { env, pipeline } from '@huggingface/transformers';

// Configure custom model host
env.remoteHost = 'https://models.mycompany.com';
env.remotePathTemplate = '{model}/{file}';

// Models will be loaded from:
// https://models.mycompany.com/my-model/model.onnx
const pipe = await pipeline('sentiment-analysis', 'my-model');
```

## Local Model Configuration

Control loading models from the local file system.

### Enable Local Models

```javascript
import { env } from '@huggingface/transformers';

// Enable local file system loading
env.allowLocalModels = true;

// Set the base path for local models
env.localModelPath = '/path/to/models/';
```

**Default values:**
- Browser: `allowLocalModels = false`, `localModelPath = '/models/'`
- Node.js: `allowLocalModels = true`, `localModelPath = '/models/'`

### File System Control

```javascript
import { env } from '@huggingface/transformers';

// Disable file system usage entirely (Node.js only)
env.useFS = false;
```

### Example: Local Model Directory Structure

```
/app/models/
├── onnx-community/
│   ├── Supertonic-TTS-ONNX/
│   │   ├── config.json
│   │   ├── tokenizer.json
│   │   ├── model.onnx
│   │   └── ...
│   └── yolo26l-pose-ONNX/
│       ├── config.json
│       ├── preprocessor_config.json
│       ├── model.onnx
│       └── ...
```

```javascript
import { env, pipeline } from '@huggingface/transformers';

env.allowLocalModels = true;
env.localModelPath = '/app/models/';
env.allowRemoteModels = false; // Offline mode

// The model ID must exist under localModelPath
const classifier = await pipeline('sentiment-analysis', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english');
```

## Cache Configuration

Transformers.js supports multiple caching strategies to improve performance and reduce network usage.

### Quick Configuration

```javascript
import { env } from '@huggingface/transformers';

// Browser cache (Cache API)
env.useBrowserCache = true;                  // default: true
env.cacheKey = 'my-app-transformers-cache';  // default: 'transformers-cache'

// Node.js filesystem cache
env.useFSCache = true;               // default: true
env.cacheDir = './custom-cache-dir'; // default: './.cache'

// Custom cache implementation
env.useCustomCache = true;
env.customCache = new CustomCache(); // Implement Cache API interface

// WASM binary caching
env.useWasmCache = true; // default: true
```

### Disable Caching

```javascript
import { env } from '@huggingface/transformers';

// Disable all caching (re-download on every load)
env.useFSCache = false;
env.useBrowserCache = false;
env.useWasmCache = false;
env.cacheDir = null;
```

For comprehensive caching documentation, including:
- Browser Cache API details and storage limits
- Node.js filesystem cache structure and management
- Custom cache implementations (Redis, database, S3)
- Cache clearing and monitoring strategies
- Best practices and troubleshooting

see **[Caching Reference](./CACHE.md)**.

## WASM Configuration

Configure the ONNX Runtime WebAssembly backend settings.

### Basic WASM Settings

```javascript
import { env } from '@huggingface/transformers';

// Set custom WASM paths
env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/';

// Configure number of threads
env.backends.onnx.wasm.numThreads = 4;

// Enable/disable SIMD (single instruction, multiple data)
env.backends.onnx.wasm.simd = true;
```

### Proxy Configuration

```javascript
import { env } from '@huggingface/transformers';

// Offload WASM execution to a worker thread
env.backends.onnx.wasm.proxy = true;
```

### Self-Hosted WASM Files

```javascript
import { env } from '@huggingface/transformers';

// Host WASM files on your own server
env.backends.onnx.wasm.wasmPaths = '/static/wasm/';
```

**Required files:**
- `ort-wasm.wasm` - Main WASM binary
- `ort-wasm-simd.wasm` - SIMD-enabled WASM binary
- `ort-wasm-threaded.wasm` - Multi-threaded WASM binary
- `ort-wasm-simd-threaded.wasm` - SIMD + multi-threaded WASM binary

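To self-host these files, one common approach is to copy them out of the installed `onnxruntime-web` package at build time. The paths below (`node_modules/onnxruntime-web/dist`, `public/static/wasm`) are assumptions that vary by project layout and onnxruntime-web version, so verify them against your install:

```shell
#!/bin/sh
set -eu

# Assumed locations -- adjust for your project and onnxruntime-web version
SRC="node_modules/onnxruntime-web/dist"
DEST="public/static/wasm"

mkdir -p "$DEST"
if [ -d "$SRC" ]; then
  # Copy every WASM binary shipped with the package
  cp "$SRC"/*.wasm "$DEST"/
  echo "Copied WASM binaries to $DEST"
else
  echo "warning: $SRC not found; is onnxruntime-web installed?" >&2
fi
```

Run it as part of your build step, then point `env.backends.onnx.wasm.wasmPaths` at the destination directory.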
## Common Configuration Patterns

### Development Setup

```javascript
import { env } from '@huggingface/transformers';

// Fast iteration with caching
env.allowRemoteModels = true;
env.useBrowserCache = true; // Browser
env.useFSCache = true;      // Node.js
env.cacheDir = './.cache';
```

### Production (Local Models)

```javascript
import { env } from '@huggingface/transformers';

// Secure, offline-capable setup
env.allowRemoteModels = false;
env.allowLocalModels = true;
env.localModelPath = '/app/models/';
env.useFSCache = false; // Models already local
```

### Offline-First Application

```javascript
import { env } from '@huggingface/transformers';

// Try local first, fall back to remote
env.allowLocalModels = true;
env.localModelPath = './models/';
env.allowRemoteModels = true;
env.useFSCache = true;
env.cacheDir = './cache';
```

### Custom CDN

```javascript
import { env } from '@huggingface/transformers';

// Use your own model hosting
env.remoteHost = 'https://cdn.example.com/ml-models';
env.remotePathTemplate = '{model}/{file}';
env.useBrowserCache = true;
```

### Memory-Constrained Environment

```javascript
import { env } from '@huggingface/transformers';

// Minimize disk/memory usage
env.useFSCache = false;
env.useBrowserCache = false;
env.useWasmCache = false;
env.cacheDir = null;
```

### Testing/CI Environment

```javascript
import { env } from '@huggingface/transformers';

// Predictable, isolated testing
env.allowRemoteModels = false;
env.allowLocalModels = true;
env.localModelPath = './test-fixtures/models/';
env.useFSCache = false;
```

## Environment Best Practices

### 1. Configure Early

Set `env` properties before loading any models:

```javascript
import { env, pipeline } from '@huggingface/transformers';

// ✓ Good: Configure before loading
env.allowRemoteModels = false;
env.localModelPath = '/app/models/';
const pipe = await pipeline('sentiment-analysis');

// ✗ Bad: Configuring after loading may not take effect
const pipe2 = await pipeline('sentiment-analysis');
env.allowRemoteModels = false; // Too late!
```

### 2. Use Environment Variables

```javascript
import { env } from '@huggingface/transformers';

// Configure based on environment
env.allowRemoteModels = process.env.NODE_ENV === 'development';
env.cacheDir = process.env.MODEL_CACHE_DIR || './.cache';
env.localModelPath = process.env.LOCAL_MODELS_PATH || '/app/models/';
```

### 3. Handle Errors Gracefully

```javascript
import { pipeline, env } from '@huggingface/transformers';

try {
  env.allowRemoteModels = false;
  const pipe = await pipeline('sentiment-analysis', 'my-model');
} catch (error) {
  if (error.message.includes('not found')) {
    console.error('Model not found locally. Enable remote models or download the model.');
  }
  throw error;
}
```

### 4. Log Configuration

```javascript
import { env } from '@huggingface/transformers';

console.log('Transformers.js Configuration:', {
  version: env.version,
  allowRemoteModels: env.allowRemoteModels,
  allowLocalModels: env.allowLocalModels,
  localModelPath: env.localModelPath,
  cacheDir: env.cacheDir,
  useFSCache: env.useFSCache,
  useBrowserCache: env.useBrowserCache
});
```

## Related Documentation

- **[Caching Reference](./CACHE.md)** - Comprehensive caching guide (browser, Node.js, custom implementations)
- [Pipeline Options](./PIPELINE_OPTIONS.md) - Configure pipeline loading with `progress_callback`, `device`, `dtype`, etc.
- [Model Architectures](./MODEL_ARCHITECTURES.md) - Supported models and architectures
- [Examples](./EXAMPLES.md) - Code examples for different runtimes
- [Main Skill Guide](../SKILL.md) - Getting started and common usage

---

# Transformers.js Code Examples

Working examples showing how to use Transformers.js across different runtimes and frameworks.

All examples use the same task and model for consistency:
- **Task**: `feature-extraction`
- **Model**: `onnx-community/all-MiniLM-L6-v2-ONNX`

## Table of Contents

1. [Browser (Vanilla JS)](#browser-vanilla-js)
2. [Node.js](#nodejs)
3. [React](#react)
4. [Express API](#express-api)

## Browser (Vanilla JS)

### Basic Usage

```html
<!DOCTYPE html>
<html>
<head>
  <title>Feature Extraction</title>
</head>
<body>
  <h1>Text Embedding Generator</h1>
  <textarea id="input" placeholder="Enter text to embed..."></textarea>
  <button onclick="generateEmbedding()">Generate Embedding</button>
  <div id="result"></div>
  <div id="loading" style="display:none;">Loading model...</div>

  <script type="module">
    import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.8.1';

    let extractor;

    // Initialize model on page load
    document.getElementById('loading').style.display = 'block';
    extractor = await pipeline(
      'feature-extraction',
      'onnx-community/all-MiniLM-L6-v2-ONNX'
    );
    document.getElementById('loading').style.display = 'none';

    window.generateEmbedding = async function() {
      const text = document.getElementById('input').value;
      const output = await extractor(text, { pooling: 'mean', normalize: true });

      document.getElementById('result').innerHTML = `
        <h3>Embedding Generated:</h3>
        <p>Dimensions: ${output.data.length}</p>
        <p>First 5 values: ${Array.from(output.data).slice(0, 5).join(', ')}</p>
      `;
    };

    // Cleanup on page unload
    window.addEventListener('beforeunload', () => {
      if (extractor) extractor.dispose();
    });
  </script>
</body>
</html>
```

### With Progress Tracking

```html
<!DOCTYPE html>
<html>
<head>
  <title>Feature Extraction with Progress</title>
  <style>
    .file-progress {
      margin: 10px 0;
    }
    .file-name {
      font-size: 12px;
      margin-bottom: 5px;
    }
    .progress-bar {
      width: 100%;
      height: 20px;
      background: #f0f0f0;
      border-radius: 5px;
      overflow: hidden;
    }
    .progress-fill {
      height: 100%;
      background: #4CAF50;
      transition: width 0.3s;
    }
  </style>
</head>
<body>
  <h1>Text Embedding Generator</h1>
  <div id="loading">
    <p id="status">Loading model...</p>
    <div id="progress-container"></div>
  </div>
  <div id="app" style="display:none;">
    <textarea id="input" placeholder="Enter text..."></textarea>
    <button onclick="generateEmbedding()">Generate</button>
    <div id="result"></div>
  </div>

  <script type="module">
    import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.8.1';

    let extractor;
    const fileProgressBars = {};
    const progressContainer = document.getElementById('progress-container');

    extractor = await pipeline(
      'feature-extraction',
      'onnx-community/all-MiniLM-L6-v2-ONNX',
      {
        progress_callback: (info) => {
          document.getElementById('status').textContent = `${info.status}: ${info.file}`;

          if (info.status === 'progress') {
            // Create a progress bar for each file
            if (!fileProgressBars[info.file]) {
              const fileDiv = document.createElement('div');
              fileDiv.className = 'file-progress';
              fileDiv.innerHTML = `
                <div class="file-name">${info.file}</div>
                <div class="progress-bar">
                  <div class="progress-fill"></div>
                </div>
              `;
              progressContainer.appendChild(fileDiv);
              fileProgressBars[info.file] = fileDiv.querySelector('.progress-fill');
            }

            // Update progress
            fileProgressBars[info.file].style.width = `${info.progress}%`;
          }

          if (info.status === 'ready') {
            document.getElementById('loading').style.display = 'none';
            document.getElementById('app').style.display = 'block';
          }
        }
      }
    );

    window.generateEmbedding = async function() {
      const text = document.getElementById('input').value;
      const output = await extractor(text, { pooling: 'mean', normalize: true });

      document.getElementById('result').innerHTML = `
        <p>Embedding: ${output.data.length} dimensions</p>
      `;
    };

    // Cleanup on page unload
    window.addEventListener('beforeunload', () => {
      if (extractor) extractor.dispose();
    });
  </script>
</body>
</html>
```

## Node.js

### Basic Script

```javascript
// embed.js
import { pipeline } from '@huggingface/transformers';

async function generateEmbedding(text) {
  const extractor = await pipeline(
    'feature-extraction',
    'onnx-community/all-MiniLM-L6-v2-ONNX'
  );

  const output = await extractor(text, { pooling: 'mean', normalize: true });

  console.log('Text:', text);
  console.log('Embedding dimensions:', output.data.length);
  console.log('First 5 values:', Array.from(output.data).slice(0, 5));

  await extractor.dispose();
}

generateEmbedding('Hello, world!');
```

### Batch Processing

```javascript
// batch-embed.js
import { pipeline } from '@huggingface/transformers';
import fs from 'fs/promises';

async function embedDocuments(documents) {
  const extractor = await pipeline(
    'feature-extraction',
    'onnx-community/all-MiniLM-L6-v2-ONNX'
  );

  console.log(`Processing ${documents.length} documents...`);

  const embeddings = [];

  for (let i = 0; i < documents.length; i++) {
    const output = await extractor(documents[i], {
      pooling: 'mean',
      normalize: true
    });

    embeddings.push({
      text: documents[i],
      embedding: Array.from(output.data)
    });

    console.log(`Processed ${i + 1}/${documents.length}`);
  }

  await fs.writeFile(
    'embeddings.json',
    JSON.stringify(embeddings, null, 2)
  );

  console.log('Saved to embeddings.json');

  await extractor.dispose();
}

const documents = [
  'The cat sat on the mat',
  'A dog played in the park',
  'Machine learning is fascinating'
];

embedDocuments(documents);
```

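Since the batch script requests `normalize: true`, the saved vectors are unit-length and a dot product already equals cosine similarity. A small helper for comparing two saved embeddings might look like this (an illustrative sketch, not part of the Transformers.js API):

```javascript
// Cosine similarity between two embedding vectors. For vectors that are
// already L2-normalized, this reduces to their dot product.
function cosineSimilarity(a, b) {
  if (a.length !== b.length) throw new Error('vector length mismatch');
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// Identical directions score 1, orthogonal directions score 0
console.log(cosineSimilarity([1, 0, 0], [1, 0, 0])); // 1
```

Applied to `embeddings.json` from the script above, this lets you rank documents by semantic similarity to a query embedding.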
### CLI with Progress

```javascript
// cli-embed.js
import { pipeline } from '@huggingface/transformers';

async function main() {
  const text = process.argv[2] || 'Hello, world!';

  console.log('Loading model...');

  const fileProgress = {};

  const extractor = await pipeline(
    'feature-extraction',
    'onnx-community/all-MiniLM-L6-v2-ONNX',
    {
      progress_callback: (info) => {
        if (info.status === 'progress') {
          fileProgress[info.file] = info.progress;

          // Show progress for all files
          const progressLines = Object.entries(fileProgress)
            .map(([file, progress]) => `  ${file}: ${progress.toFixed(1)}%`)
            .join('\n');

          process.stdout.write(`\r\x1b[K${progressLines}`);
        }

        if (info.status === 'done') {
          console.log(`\n✓ ${info.file} complete`);
        }

        if (info.status === 'ready') {
          console.log('\nModel ready!');
        }
      }
    }
  );

  console.log('Generating embedding...');
  const output = await extractor(text, { pooling: 'mean', normalize: true });

  console.log(`\nText: "${text}"`);
  console.log(`Dimensions: ${output.data.length}`);
  console.log(`First 5 values: ${Array.from(output.data).slice(0, 5).join(', ')}`);

  await extractor.dispose();
}

main();
```

## React

### Basic Component

```jsx
// EmbeddingGenerator.jsx
import { useState, useRef, useEffect } from 'react';
import { pipeline } from '@huggingface/transformers';

export function EmbeddingGenerator() {
  const extractorRef = useRef(null);
  const [text, setText] = useState('');
  const [embedding, setEmbedding] = useState(null);
  const [loading, setLoading] = useState(false);

  const generate = async () => {
    if (!text) return;

    setLoading(true);

    // Load model on first generate
    if (!extractorRef.current) {
      extractorRef.current = await pipeline(
        'feature-extraction',
        'onnx-community/all-MiniLM-L6-v2-ONNX'
      );
    }

    const output = await extractorRef.current(text, {
      pooling: 'mean',
      normalize: true
    });
    setEmbedding(Array.from(output.data));
    setLoading(false);
  };

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      if (extractorRef.current) {
        extractorRef.current.dispose();
      }
    };
  }, []);

  return (
    <div>
      <h2>Text Embedding Generator</h2>

      <textarea
        value={text}
        onChange={(e) => setText(e.target.value)}
        placeholder="Enter text"
        disabled={loading}
      />

      <button onClick={generate} disabled={loading || !text}>
        {loading ? 'Processing...' : 'Generate Embedding'}
      </button>

      {embedding && (
        <div>
          <h3>Result:</h3>
          <p>Dimensions: {embedding.length}</p>
          <p>First 5 values: {embedding.slice(0, 5).join(', ')}</p>
        </div>
      )}
    </div>
  );
}
```

### With Progress Tracking

```jsx
// EmbeddingGeneratorWithProgress.jsx
import { useState, useRef, useEffect } from 'react';
import { pipeline } from '@huggingface/transformers';

export function EmbeddingGeneratorWithProgress() {
  const extractorRef = useRef(null);
  const [text, setText] = useState('');
  const [embedding, setEmbedding] = useState(null);
  const [fileProgress, setFileProgress] = useState({});
  const [status, setStatus] = useState('');
  const [loading, setLoading] = useState(false);

  const generate = async () => {
    if (!text) return;

    setLoading(true);

    // Load model on first generate
    if (!extractorRef.current) {
      setStatus('Loading model...');

      extractorRef.current = await pipeline(
        'feature-extraction',
        'onnx-community/all-MiniLM-L6-v2-ONNX',
        {
          progress_callback: (info) => {
            setStatus(`${info.status}: ${info.file}`);

            if (info.status === 'progress') {
              setFileProgress(prev => ({
                ...prev,
                [info.file]: info.progress
              }));
            }

            if (info.status === 'ready') {
              setStatus('Model ready!');
            }
          }
        }
      );
    }

    setStatus('Generating embedding...');
    const output = await extractorRef.current(text, {
      pooling: 'mean',
      normalize: true
    });
    setEmbedding(Array.from(output.data));
    setStatus('Complete!');
    setLoading(false);
  };

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      if (extractorRef.current) {
        extractorRef.current.dispose();
      }
    };
  }, []);

  return (
    <div>
      <h2>Text Embedding Generator</h2>

      {loading && Object.keys(fileProgress).length > 0 && (
        <div>
          <p>{status}</p>
          {Object.entries(fileProgress).map(([file, progress]) => (
            <div key={file} style={{ margin: '10px 0' }}>
              <div style={{ fontSize: '12px', marginBottom: '5px' }}>{file}</div>
              <div style={{ width: '100%', height: '20px', background: '#f0f0f0', borderRadius: '5px', overflow: 'hidden' }}>
                <div
                  style={{
                    width: `${progress}%`,
                    height: '100%',
                    background: '#4CAF50',
                    transition: 'width 0.3s'
                  }}
                />
              </div>
            </div>
          ))}
        </div>
      )}

      <textarea
        value={text}
        onChange={(e) => setText(e.target.value)}
        placeholder="Enter text"
        disabled={loading}
      />

      <button onClick={generate} disabled={loading || !text}>
        {loading ? 'Processing...' : 'Generate Embedding'}
      </button>

      {embedding && (
        <div>
          <h3>Result:</h3>
          <p>Dimensions: {embedding.length}</p>
          <p>First 5 values: {embedding.slice(0, 5).join(', ')}</p>
        </div>
      )}
    </div>
  );
}
```

## Express API
|
||||
|
||||
### Basic API Server
|
||||
|
||||
```javascript
|
||||
// server.js
|
||||
import express from 'express';
|
||||
import { pipeline } from '@huggingface/transformers';
|
||||
|
||||
const app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Initialize model once at startup
|
||||
let extractor;
|
||||
(async () => {
|
||||
console.log('Loading model...');
|
||||
extractor = await pipeline(
|
||||
'feature-extraction',
|
||||
'onnx-community/all-MiniLM-L6-v2-ONNX'
|
||||
);
|
||||
console.log('Model ready!');
|
||||
})();
|
||||
|
||||
app.post('/embed', async (req, res) => {
|
||||
try {
|
||||
const { text } = req.body;
|
||||
|
||||
if (!text) {
|
||||
return res.status(400).json({ error: 'Text is required' });
|
||||
}
|
||||
|
||||
const output = await extractor(text, {
|
||||
pooling: 'mean',
|
||||
normalize: true
|
||||
});
|
||||
|
||||
res.json({
|
||||
text,
|
||||
embedding: Array.from(output.data),
|
||||
dimensions: output.data.length
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
res.status(500).json({ error: 'Failed to generate embedding' });
|
||||
}
|
||||
});
|
||||
|
||||
app.listen(3000, () => {
|
||||
console.log('Server running on http://localhost:3000');
|
||||
});
|
||||
```
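
Because the pipeline is called with `normalize: true`, the embeddings it returns are unit vectors, so comparing two texts reduces to a dot product. A minimal, library-independent helper (the function name is illustrative; the general form below also handles non-normalized vectors):

```javascript
// Cosine similarity of two embedding vectors.
// For unit-normalized vectors (normalize: true), this equals the plain dot product.
function cosineSimilarity(a, b) {
  if (a.length !== b.length) throw new Error('Dimension mismatch');
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}
```

Identical texts score near 1, unrelated texts near 0.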

### API with Graceful Shutdown

```javascript
// server-with-shutdown.js
import express from 'express';
import { pipeline } from '@huggingface/transformers';

const app = express();
app.use(express.json());

let extractor;
let server;

async function initialize() {
  console.log('Loading model...');
  extractor = await pipeline(
    'feature-extraction',
    'onnx-community/all-MiniLM-L6-v2-ONNX'
  );
  console.log('Model ready!');
}

app.post('/embed', async (req, res) => {
  try {
    const { text } = req.body;

    if (!text) {
      return res.status(400).json({ error: 'Text is required' });
    }

    const output = await extractor(text, {
      pooling: 'mean',
      normalize: true
    });

    res.json({
      embedding: Array.from(output.data),
      dimensions: output.data.length
    });
  } catch (error) {
    res.status(500).json({ error: error.message });
  }
});

async function shutdown(signal) {
  console.log(`\n${signal} received. Shutting down...`);

  // Stop accepting new connections and wait for the server to finish closing
  if (server) {
    await new Promise((resolve) => server.close(resolve));
    console.log('HTTP server closed');
  }

  if (extractor) {
    console.log('Disposing model...');
    await extractor.dispose();
    console.log('Model disposed');
  }

  process.exit(0);
}

process.on('SIGTERM', () => shutdown('SIGTERM'));
process.on('SIGINT', () => shutdown('SIGINT'));

initialize().then(() => {
  server = app.listen(3000, () => {
    console.log('Server running on http://localhost:3000');
  });
});
```

---

These examples demonstrate the same functionality across different runtimes and frameworks, making it easy to adapt to your specific use case. All examples include proper cleanup with `.dispose()` to free memory.

# Supported Model Architectures

This document lists the model architectures currently supported by Transformers.js.

## Natural Language Processing

### Text Models

- **ALBERT** - A Lite BERT for Self-supervised Learning
- **BERT** - Bidirectional Encoder Representations from Transformers
- **CamemBERT** - French language model based on RoBERTa
- **CodeGen** - Code generation models
- **CodeLlama** - Code-focused Llama models
- **Cohere** - Command-R models for RAG
- **DeBERTa** - Decoding-enhanced BERT with Disentangled Attention
- **DeBERTa-v2** - Improved version of DeBERTa
- **DistilBERT** - Distilled version of BERT (smaller, faster)
- **GPT-2** - Generative Pre-trained Transformer 2
- **GPT-Neo** - Open source GPT-3 alternative
- **GPT-NeoX** - Larger GPT-Neo models
- **LLaMA** - Large Language Model Meta AI
- **Mistral** - Mistral AI language models
- **MPNet** - Masked and Permuted Pre-training
- **MobileBERT** - Compressed BERT for mobile devices
- **RoBERTa** - Robustly Optimized BERT
- **T5** - Text-to-Text Transfer Transformer
- **XLM-RoBERTa** - Multilingual RoBERTa

### Sequence-to-Sequence

- **BART** - Denoising Sequence-to-Sequence Pre-training
- **Blenderbot** - Open-domain chatbot
- **BlenderbotSmall** - Smaller Blenderbot variant
- **M2M100** - Many-to-Many multilingual translation
- **MarianMT** - Neural machine translation
- **mBART** - Multilingual BART
- **NLLB** - No Language Left Behind (200 languages)
- **Pegasus** - Pre-training with extracted gap-sentences

## Computer Vision

### Image Classification

- **BEiT** - BERT Pre-Training of Image Transformers
- **ConvNeXT** - Modern ConvNet architecture
- **ConvNeXTV2** - Improved ConvNeXT
- **DeiT** - Data-efficient Image Transformers
- **DINOv2** - Self-supervised Vision Transformer
- **DINOv3** - Latest DINO iteration
- **EfficientNet** - Efficient convolutional networks
- **MobileNet** - Lightweight models for mobile
- **MobileViT** - Mobile Vision Transformer
- **ResNet** - Residual Networks
- **SegFormer** - Semantic segmentation transformer
- **Swin** - Shifted Window Transformer
- **ViT** - Vision Transformer

### Object Detection

- **DETR** - Detection Transformer
- **D-FINE** - Fine-grained Distribution Refinement for object detection
- **DINO** - DETR with Improved deNoising anchOr boxes
- **Grounding DINO** - Open-set object detection
- **YOLOS** - You Only Look at One Sequence

### Segmentation

- **CLIPSeg** - Image segmentation with text prompts
- **Mask2Former** - Universal image segmentation
- **SAM** - Segment Anything Model
- **EdgeTAM** - On-Device Track Anything Model

### Depth & Pose

- **DPT** - Dense Prediction Transformer
- **Depth Anything** - Monocular depth estimation
- **Depth Pro** - Sharp monocular metric depth
- **GLPN** - Global-Local Path Networks for depth

## Audio

### Speech Recognition

- **Wav2Vec2** - Self-supervised speech representations
- **Whisper** - Robust speech recognition (multilingual)
- **HuBERT** - Self-supervised speech representation learning

### Audio Processing

- **Audio Spectrogram Transformer** - Audio classification
- **DAC** - Descript Audio Codec

### Text-to-Speech

- **SpeechT5** - Unified speech and text pre-training
- **VITS** - Conditional Variational Autoencoder with adversarial learning

## Multimodal

### Vision-Language

- **CLIP** - Contrastive Language-Image Pre-training
- **Chinese-CLIP** - Chinese version of CLIP
- **ALIGN** - Large-scale noisy image-text pairs
- **BLIP** - Bootstrapping Language-Image Pre-training
- **Florence-2** - Unified vision foundation model
- **LLaVA** - Large Language and Vision Assistant
- **Moondream** - Tiny vision-language model

### Document Understanding

- **DiT** - Document Image Transformer
- **Donut** - OCR-free Document Understanding
- **LayoutLM** - Pre-training for document understanding
- **TrOCR** - Transformer-based OCR

### Audio-Language

- **CLAP** - Contrastive Language-Audio Pre-training

## Embeddings & Similarity

- **Sentence Transformers** - Sentence embeddings
- **all-MiniLM** - Efficient sentence embeddings
- **all-mpnet-base** - High-quality sentence embeddings
- **E5** - Text embeddings by Microsoft
- **BGE** - General embedding models
- **nomic-embed** - Long context embeddings

## Specialized Models

### Code

- **CodeBERT** - Pre-trained model for code
- **GraphCodeBERT** - Code structure understanding
- **StarCoder** - Code generation

### Scientific

- **SciBERT** - Scientific text
- **BioBERT** - Biomedical text

### Retrieval

- **ColBERT** - Contextualized late interaction over BERT
- **DPR** - Dense Passage Retrieval

## Model Selection Tips

### For Text Tasks

- **Small & Fast**: DistilBERT, MobileBERT
- **Balanced**: BERT-base, RoBERTa-base
- **High Accuracy**: RoBERTa-large, DeBERTa-v3-large
- **Multilingual**: XLM-RoBERTa, mBERT

### For Vision Tasks

- **Mobile/Browser**: MobileNet, EfficientNet-B0
- **Balanced**: DeiT-base, ConvNeXT-tiny
- **High Accuracy**: Swin-large, DINOv2-large

### For Audio Tasks

- **Speech Recognition**: Whisper-tiny (fast), Whisper-large (accurate)
- **Audio Classification**: Audio Spectrogram Transformer

### For Multimodal

- **Vision-Language**: CLIP (general), Florence-2 (comprehensive)
- **Document AI**: Donut, LayoutLM
- **OCR**: TrOCR

## Finding Models on Hugging Face Hub

Search for compatible models:

```
https://huggingface.co/models?library=transformers.js
```

Filter by task:

```
https://huggingface.co/models?pipeline_tag=text-classification&library=transformers.js
```

Check for ONNX support by looking for an `onnx/` folder in the model repository.
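This check can also be scripted. The sketch below assumes the Hub's public tree endpoint (`/api/models/{id}/tree/main`) returns an array of `{ path }` entries; treat both the endpoint shape and the helper names as illustrative:

```javascript
// Given a Hub file listing (array of { path } entries), check for an onnx/ folder.
function hasOnnxFolder(entries) {
  return entries.some((e) => e.path === 'onnx' || e.path.startsWith('onnx/'));
}

// Hypothetical usage against the Hub API (endpoint shape is an assumption):
async function modelHasOnnx(modelId) {
  const res = await fetch(`https://huggingface.co/api/models/${modelId}/tree/main`);
  const entries = await res.json();
  return hasOnnxFolder(entries);
}
```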
# Pipeline Options Reference

A guide to configuring model loading and inference through the `options` parameter (`PretrainedModelOptions`) of the `pipeline()` function.

## Table of Contents

1. [Overview](#overview)
2. [Basic Options](#basic-options)
3. [Model Loading Options](#model-loading-options)
4. [Device and Performance Options](#device-and-performance-options)
5. [Common Configuration Patterns](#common-configuration-patterns)

## Overview

The `pipeline()` function accepts three parameters:

```javascript
import { pipeline } from '@huggingface/transformers';

const pipe = await pipeline(
  'task-name',  // 1. Task type (e.g., 'sentiment-analysis')
  'model-id',   // 2. Model identifier (optional, uses default if null)
  options       // 3. PretrainedModelOptions (optional)
);
```

The third parameter, `options`, allows you to configure how the model is loaded and executed.

### Available Options

```typescript
interface PretrainedModelOptions {
  // Progress tracking
  progress_callback?: (info: ProgressInfo) => void;

  // Model configuration
  config?: PretrainedConfig;

  // Cache and loading
  cache_dir?: string;
  local_files_only?: boolean;
  revision?: string;

  // Model-specific settings
  subfolder?: string;
  model_file_name?: string;

  // Device and performance
  device?: DeviceType | Record<string, DeviceType>;
  dtype?: DataType | Record<string, DataType>;

  // External data format (large models)
  use_external_data_format?: boolean | number | Record<string, boolean | number>;

  // ONNX Runtime settings
  session_options?: InferenceSession.SessionOptions;
}
```

## Basic Options

### Progress Callback

Track model download and loading progress. **Note:** Models consist of multiple files (model weights, config, tokenizer, etc.), and each file reports its own progress:

```javascript
const fileProgress = {};

const pipe = await pipeline('sentiment-analysis', null, {
  progress_callback: (info) => {
    if (info.status === 'progress') {
      fileProgress[info.file] = info.progress;
      console.log(`${info.file}: ${info.progress.toFixed(1)}%`);
    }

    if (info.status === 'done') {
      console.log(`✓ ${info.file} complete`);
    }
  }
});
```

**Progress Info Types:**

```typescript
type ProgressInfo = {
  status: 'initiate' | 'download' | 'progress' | 'done' | 'ready';
  name: string;      // Model id or path
  file: string;      // File being processed
  progress?: number; // Percentage (0-100, only for 'progress' status)
  loaded?: number;   // Bytes downloaded (only for 'progress' status)
  total?: number;    // Total bytes (only for 'progress' status)
};
```
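
Because each file reports independently, a single overall percentage has to be derived on the client. A small sketch that aggregates `loaded`/`total` bytes across files, using the event fields from the `ProgressInfo` type above (the factory name is illustrative):

```javascript
// Aggregate per-file 'progress' events into one overall percentage.
function createProgressAggregator() {
  const files = new Map();
  return {
    update(info) {
      if (info.status === 'progress') {
        files.set(info.file, { loaded: info.loaded, total: info.total });
      }
    },
    overall() {
      let loaded = 0, total = 0;
      for (const f of files.values()) {
        loaded += f.loaded;
        total += f.total;
      }
      return total === 0 ? 0 : (loaded / total) * 100;
    },
  };
}
```

Pass each callback event to `update(info)` and render `overall()` in the UI. Note the figure is weighted by bytes, so small files (config, tokenizer) barely move it.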

**Example: Browser Loading UI with Multiple Files**

```javascript
const statusDiv = document.getElementById('status');
const progressContainer = document.getElementById('progress-container');
const fileProgressBars = {};

const pipe = await pipeline('image-classification', null, {
  progress_callback: (info) => {
    if (info.status === 'progress') {
      // Create a progress bar for each file if it does not exist yet
      if (!fileProgressBars[info.file]) {
        const fileDiv = document.createElement('div');
        fileDiv.innerHTML = `
          <div class="file-name">${info.file}</div>
          <div class="progress-bar">
            <div class="progress-fill" style="width: 0%"></div>
          </div>
        `;
        progressContainer.appendChild(fileDiv);
        fileProgressBars[info.file] = fileDiv.querySelector('.progress-fill');
      }

      // Update the progress bar
      fileProgressBars[info.file].style.width = `${info.progress}%`;

      const mb = (info.loaded / 1024 / 1024).toFixed(2);
      const totalMb = (info.total / 1024 / 1024).toFixed(2);
      statusDiv.textContent = `${info.file}: ${mb}/${totalMb} MB`;
    }

    if (info.status === 'ready') {
      statusDiv.textContent = 'Model ready!';
    }
  }
});
```

For more progress-tracking patterns, see the framework examples earlier in this document.

### Custom Configuration

Override the model's default configuration:

```javascript
import { pipeline } from '@huggingface/transformers';

const pipe = await pipeline('text-generation', 'model-id', {
  config: {
    max_length: 512,
    temperature: 0.8,
    // ... other config options
  }
});
```

**Use cases:**
- Override default generation parameters
- Adjust model-specific settings
- Test different configurations without modifying model files

## Model Loading Options

### Cache Directory

Specify where to cache downloaded models:

```javascript
// Node.js: Custom cache location
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  cache_dir: './my-custom-cache'
});
```

**Default behavior:**
- If not specified, uses `env.cacheDir` (default: `./.cache`)
- Only applies when `env.useFSCache = true` (Node.js)
- Browser cache uses the Cache API (configured via `env.cacheKey`)

### Local Files Only

Prevent any network requests:

```javascript
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  local_files_only: true
});
```

**Use cases:**
- Offline applications
- Air-gapped environments
- Testing with pre-downloaded models
- Production deployments with bundled models

**Important:**
- The model must already be cached or available locally
- Throws an error if the model is not found locally
- Requires `env.allowLocalModels = true`

### Model Revision

Pin a specific model version (git branch, tag, or commit):

```javascript
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  revision: 'v1.0.0' // Use a specific version
});

// Or use a branch
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  revision: 'experimental'
});

// Or use a commit hash
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  revision: 'abc123def456'
});
```

**Default:** `'main'` (latest version)

**Use cases:**
- Pin to a stable release for production
- Test experimental features
- Reproduce results with a specific model version
- Work with models under development

**Important:**
- Only applies to remote models (Hugging Face Hub)
- Ignored for local file paths
- Each revision is cached separately

### Model Subfolder

Specify the subfolder within the model repository:

```javascript
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  subfolder: 'onnx' // Default: 'onnx'
});
```

**Default:** `'onnx'`

**Use cases:**
- Custom model repository structure
- Multiple model variants in the same repo
- Organizational preferences

### Model File Name

Specify a custom model file name (without the `.onnx` extension):

```javascript
const pipe = await pipeline('text-generation', 'model-id', {
  model_file_name: 'decoder_model_merged'
});
// Loads: decoder_model_merged.onnx
```

**Use cases:**
- Models with non-standard file names
- Selecting a specific model variant
- Encoder-decoder models with separate files

**Note:** Currently only valid for encoder-only or decoder-only models.

## Device and Performance Options

### Device Selection

Choose where to run the model:

```javascript
// Run on CPU (WASM - default)
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  device: 'wasm'
});

// Run on GPU (WebGPU)
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  device: 'webgpu'
});
```

**Common devices:**
- `'wasm'` - WebAssembly (CPU, most compatible)
- `'webgpu'` - WebGPU (GPU, faster in browsers)
- `'cpu'` - CPU
- `'gpu'` - Auto-detect GPU
- `'cuda'` - NVIDIA CUDA (Node.js with GPU)

See the full list in the [devices.js source](https://github.com/huggingface/transformers.js/blob/main/src/utils/devices.js).

**Per-component device selection:**

For models with multiple components (encoder-decoder, vision-encoder-decoder, etc.):

```javascript
const pipe = await pipeline('automatic-speech-recognition', 'model-id', {
  device: {
    encoder: 'webgpu', // Run encoder on GPU
    decoder: 'wasm'    // Run decoder on CPU
  }
});
```

**WebGPU Requirements:**
- Chrome/Edge 113+
- Enable `chrome://flags/#enable-unsafe-webgpu` (if needed)
- Adequate GPU memory

### Data Type (Quantization)

Control model precision and size:

```javascript
// Full precision (largest, most accurate)
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  dtype: 'fp32'
});

// Half precision (balanced)
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  dtype: 'fp16'
});

// 8-bit quantization (smaller, faster)
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  dtype: 'q8'
});

// 4-bit quantization (smallest, fastest)
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  dtype: 'q4'
});
```

**Common data types:**
- `'fp32'` - 32-bit floating point (full precision)
- `'fp16'` - 16-bit floating point (half precision)
- `'q8'` - 8-bit quantized (good balance)
- `'q4'` - 4-bit quantized (maximum compression)
- `'int8'` - 8-bit integer
- `'uint8'` - 8-bit unsigned integer

See the full list in the [dtypes.js source](https://github.com/huggingface/transformers.js/blob/main/src/utils/dtypes.js).

**Per-component data type:**

```javascript
const pipe = await pipeline('automatic-speech-recognition', 'model-id', {
  dtype: {
    encoder: 'fp32', // Encoder at full precision
    decoder: 'q8'    // Decoder quantized
  }
});
```

**Trade-offs:**

| Data Type | Model Size | Speed | Accuracy | Use Case |
|-----------|-----------|-------|----------|----------|
| `fp32` | Largest | Slowest | Highest | Research, maximum quality |
| `fp16` | Medium | Medium | High | Production, GPU inference |
| `q8` | Small | Fast | Good | Production, CPU inference |
| `q4` | Smallest | Fastest | Acceptable | Edge devices, real-time apps |
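
A back-of-the-envelope size estimate follows from bytes per parameter: 4 for `fp32`, 2 for `fp16`, 1 for `q8`, and roughly 0.5 for `q4`. The sketch below is an approximation that ignores per-tensor overhead and any layers kept at higher precision:

```javascript
// Rough model-size estimate from parameter count and dtype.
const BYTES_PER_PARAM = { fp32: 4, fp16: 2, q8: 1, q4: 0.5 };

function estimateModelSizeMB(numParams, dtype) {
  const bpp = BYTES_PER_PARAM[dtype];
  if (bpp === undefined) throw new Error(`Unknown dtype: ${dtype}`);
  return (numParams * bpp) / (1024 * 1024);
}
```

This is only a planning aid for picking a dtype before downloading; the actual file sizes on the Hub are authoritative.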
### External Data Format

For models >= 2 GB, ONNX uses the external data format:

```javascript
// Automatically detect and load external data
const pipe = await pipeline('text-generation', 'large-model-id', {
  use_external_data_format: true
});

// Specify the number of external data chunks
const pipe = await pipeline('text-generation', 'large-model-id', {
  use_external_data_format: 5 // Load 5 chunks (model.onnx_data_0 to _4)
});
```

**How it works:**
- Models >= 2 GB split weights into separate files
- Main file: `model.onnx` (structure only)
- Data files: `model.onnx_data` or `model.onnx_data_0`, `model.onnx_data_1`, etc.

**Default behavior:**
- `false` - No external data (models < 2 GB)
- `true` - Load external data automatically
- `number` - Load this many external data chunks

**Maximum chunks:** 100 (defined by `MAX_EXTERNAL_DATA_CHUNKS`)

**Per-component external data:**

```javascript
const pipe = await pipeline('text-generation', 'large-model-id', {
  use_external_data_format: {
    encoder: true,
    decoder: 3 // Decoder has 3 external data chunks
  }
});
```

### Session Options

Advanced ONNX Runtime configuration:

```javascript
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  session_options: {
    executionProviders: ['webgpu', 'wasm'],
    graphOptimizationLevel: 'all',
    enableCpuMemArena: true,
    enableMemPattern: true,
    executionMode: 'sequential',
    logSeverityLevel: 2,
    logVerbosityLevel: 0
  }
});
```

**Common session options:**

| Option | Description | Default |
|--------|-------------|---------|
| `executionProviders` | Ordered list of execution providers | `['wasm']` |
| `graphOptimizationLevel` | Graph optimization: `'disabled'`, `'basic'`, `'extended'`, `'all'` | `'all'` |
| `enableCpuMemArena` | Enable CPU memory arena for faster memory allocation | `true` |
| `enableMemPattern` | Enable memory pattern optimization | `true` |
| `executionMode` | `'sequential'` or `'parallel'` | `'sequential'` |
| `logSeverityLevel` | 0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal | `2` |
| `freeDimensionOverrides` | Override dynamic dimensions (e.g., `{ batch_size: 1 }`) | - |

**Use cases:**
- Fine-tune performance for specific hardware
- Debug model execution issues
- Override dynamic shapes
- Control memory usage

## Common Configuration Patterns

### Development

Fast iteration with progress tracking:

```javascript
import { pipeline } from '@huggingface/transformers';

const pipe = await pipeline('sentiment-analysis', null, {
  progress_callback: (info) => {
    if (info.status === 'progress') {
      console.log(`${info.file}: ${info.progress.toFixed(1)}%`);
    }
  }
});
```

### Production (GPU)

Use WebGPU with fp16 for better performance:

```javascript
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  device: 'webgpu',
  dtype: 'fp16'
});
```

### Production (CPU)

Use quantization for smaller size and faster CPU inference:

```javascript
const pipe = await pipeline('sentiment-analysis', 'model-id', {
  dtype: 'q8' // or 'q4' for even smaller
});
```

### Offline/Local

Prevent network requests, use only local models:

```javascript
import { pipeline, env } from '@huggingface/transformers';

env.allowLocalModels = true;
env.localModelPath = './models/';

const pipe = await pipeline('sentiment-analysis', 'model-id', {
  local_files_only: true
});
```

### Per-Component Settings

For encoder-decoder models, configure each component separately:

```javascript
const pipe = await pipeline('automatic-speech-recognition', 'model-id', {
  device: {
    encoder: 'webgpu',
    decoder: 'wasm'
  },
  dtype: {
    encoder: 'fp16',
    decoder: 'q8'
  }
});
```

## Related Documentation

- [Configuration Reference](./CONFIGURATION.md) - Environment configuration with `env` object
- [Text Generation Guide](./TEXT_GENERATION.md) - Text generation options and streaming
- [Model Architectures](./MODEL_ARCHITECTURES.md) - Supported models and selection tips
- [Main Skill Guide](../SKILL.md) - Getting started with Transformers.js

## Best Practices

1. **Progress Callbacks**: Use `progress_callback` for large models to show download progress
2. **Quantization**: Use `q8` or `q4` for CPU inference to reduce size and improve speed
3. **Device Selection**: Use `webgpu` for better performance when available
4. **Offline-First**: Use `local_files_only: true` in production to avoid runtime downloads
5. **Version Pinning**: Use `revision` to pin model versions for reproducible deployments
6. **Memory Management**: Always dispose pipelines with `pipe.dispose()` when done
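
Several of the practices above can be rolled into one options object and passed as the third argument to `pipeline()`. A sketch with illustrative values (the revision string and helper name are assumptions):

```javascript
// Options object combining practices 1, 2, 4, and 5 above. Purely illustrative values.
function buildProductionOptions() {
  return {
    revision: 'v1.0.0',      // pin the model version
    dtype: 'q8',             // quantize for CPU inference
    local_files_only: true,  // no runtime downloads
    progress_callback: (info) => { // surface loading progress
      if (info.status === 'progress') {
        console.log(`${info.file}: ${info.progress.toFixed(1)}%`);
      }
    },
  };
}
```

Usage would look like `await pipeline('sentiment-analysis', 'model-id', buildProductionOptions())`.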

---

This document covers all available options for the `pipeline()` function. For environment-level configuration (remote hosts, global cache settings, WASM paths), see the [Configuration Reference](./CONFIGURATION.md).

# Text Generation Guide

Guide to generating text with Transformers.js, including streaming and chat format.

## Table of Contents

1. [Basic Generation](#basic-generation)
2. [Streaming](#streaming)
3. [Chat Format](#chat-format)
4. [Generation Parameters](#generation-parameters)
5. [Model Selection](#model-selection)
6. [Best Practices](#best-practices)

## Basic Generation

```javascript
import { pipeline } from '@huggingface/transformers';

const generator = await pipeline(
  'text-generation',
  'onnx-community/Qwen2.5-0.5B-Instruct',
  { dtype: 'q4' }
);

const result = await generator('Once upon a time', {
  max_new_tokens: 100,
  temperature: 0.7,
});

console.log(result[0].generated_text);

// Clean up when done
await generator.dispose();
```

## Streaming

Stream tokens as they're generated for better UX. Once you understand streaming, you can combine it with other features such as chat format.

### Node.js

```javascript
import { pipeline, TextStreamer } from '@huggingface/transformers';

const generator = await pipeline(
  'text-generation',
  'onnx-community/Qwen2.5-0.5B-Instruct',
  { dtype: 'q4' }
);

const streamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  skip_special_tokens: true,
  callback_function: (token) => {
    process.stdout.write(token);
  },
});

await generator('Tell me a story', {
  max_new_tokens: 200,
  temperature: 0.7,
  streamer,
});
```

### Browser

```html
<!DOCTYPE html>
<html>
<body>
  <textarea id="prompt" placeholder="Enter prompt..."></textarea>
  <button onclick="generate()">Generate</button>
  <div id="output"></div>

  <script type="module">
    import { pipeline, TextStreamer } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.8.1';

    const generator = await pipeline(
      'text-generation',
      'onnx-community/Qwen2.5-0.5B-Instruct',
      { dtype: 'q4' }
    );

    window.generate = async function() {
      const prompt = document.getElementById('prompt').value;
      const outputDiv = document.getElementById('output');
      outputDiv.textContent = '';

      const streamer = new TextStreamer(generator.tokenizer, {
        skip_prompt: true,
        skip_special_tokens: true,
        callback_function: (token) => {
          outputDiv.textContent += token;
        },
      });

      await generator(prompt, {
        max_new_tokens: 200,
        temperature: 0.7,
        streamer,
      });
    };
  </script>
</body>
</html>
```

### React

```jsx
import { useState, useRef, useEffect } from 'react';
import { pipeline, TextStreamer } from '@huggingface/transformers';

function StreamingGenerator() {
  const generatorRef = useRef(null);
  const [output, setOutput] = useState('');
  const [loading, setLoading] = useState(false);

  const handleGenerate = async (prompt) => {
    if (!prompt) return;

    setLoading(true);
    setOutput('');

    // Load model on first generate
    if (!generatorRef.current) {
      generatorRef.current = await pipeline(
        'text-generation',
        'onnx-community/Qwen2.5-0.5B-Instruct',
        { dtype: 'q4' }
      );
    }

    const streamer = new TextStreamer(generatorRef.current.tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function: (token) => {
        setOutput((prev) => prev + token);
      },
    });

    await generatorRef.current(prompt, {
      max_new_tokens: 200,
      temperature: 0.7,
      streamer,
    });

    setLoading(false);
  };

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      if (generatorRef.current) {
        generatorRef.current.dispose();
      }
    };
  }, []);

  return (
    <div>
      <button onClick={() => handleGenerate('Tell me a story')} disabled={loading}>
        {loading ? 'Generating...' : 'Generate'}
      </button>
      <div>{output}</div>
    </div>
  );
}
```
|
||||
|
||||
## Chat Format

Use structured messages for conversations. Works with both basic generation and streaming (just add the `streamer` parameter).

### Single Turn

```javascript
import { pipeline } from '@huggingface/transformers';

const generator = await pipeline(
  'text-generation',
  'onnx-community/Qwen2.5-0.5B-Instruct',
  { dtype: 'q4' }
);

const messages = [
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'How do I create an async function?' }
];

const result = await generator(messages, {
  max_new_tokens: 256,
  temperature: 0.7,
});

console.log(result[0].generated_text);
```

### Multi-turn Conversation

```javascript
const conversation = [
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'What is JavaScript?' },
  { role: 'assistant', content: 'JavaScript is a programming language...' },
  { role: 'user', content: 'Can you show an example?' }
];

const result = await generator(conversation, {
  max_new_tokens: 200,
  temperature: 0.7,
});

// To add streaming, just pass a streamer:
// streamer: new TextStreamer(generator.tokenizer, {...})
```

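To keep a multi-turn session going, each reply has to be folded back into the message array before the next call. A minimal sketch of that bookkeeping is below; `appendTurn` is a hypothetical helper name, not part of the transformers.js API, and the shape of the generated reply is assumed to match the message format shown above.

```javascript
// appendTurn: immutably extend a chat history with one user/assistant exchange.
// Returning a new array (instead of pushing) makes it easy to retry a turn
// or branch a conversation without corrupting earlier state.
function appendTurn(history, userText, assistantText) {
  return [
    ...history,
    { role: 'user', content: userText },
    { role: 'assistant', content: assistantText },
  ];
}
```

A usage sketch, reusing the `generator` from the examples above: call it with `history` plus the new user message, extract the assistant's text from the result, then `history = appendTurn(history, userText, assistantText)` before the next turn.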
## Generation Parameters

### Common Parameters

```javascript
await generator(prompt, {
  // Token limits
  max_new_tokens: 512,       // Maximum tokens to generate
  min_new_tokens: 0,         // Minimum tokens to generate

  // Sampling
  temperature: 0.7,          // Randomness (0.0-2.0)
  top_k: 50,                 // Consider top K tokens
  top_p: 0.95,               // Nucleus sampling
  do_sample: true,           // Use random sampling (false = always pick most likely token)

  // Repetition control
  repetition_penalty: 1.0,   // Penalty for repeating (1.0 = no penalty)
  no_repeat_ngram_size: 0,   // Prevent repeating n-grams

  // Streaming
  streamer: streamer,        // TextStreamer instance
});
```

### Parameter Effects

**Temperature:**
- Low (0.1-0.5): More focused and deterministic
- Medium (0.6-0.9): Balanced creativity and coherence
- High (1.0-2.0): More creative and random

```javascript
// Focused output
await generator(prompt, { temperature: 0.3, max_new_tokens: 100 });

// Creative output
await generator(prompt, { temperature: 1.2, max_new_tokens: 100 });
```

**Sampling Methods:**

```javascript
// Greedy (deterministic)
await generator(prompt, {
  do_sample: false,
  max_new_tokens: 100
});

// Top-k sampling
await generator(prompt, {
  top_k: 50,
  temperature: 0.7,
  max_new_tokens: 100
});

// Top-p (nucleus) sampling
await generator(prompt, {
  top_p: 0.95,
  temperature: 0.7,
  max_new_tokens: 100
});
```

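To make the parameter effects concrete, here is a dependency-free sketch of what temperature scaling and `top_k` do to a token distribution before a token is drawn. The function names are illustrative only, not transformers.js internals.

```javascript
// softmaxWithTemperature: lower T sharpens the distribution (more deterministic),
// higher T flattens it (more random).
function softmaxWithTemperature(logits, temperature) {
  const scaled = logits.map((l) => l / temperature);
  const max = Math.max(...scaled);                // subtract max to stabilize exp()
  const exps = scaled.map((l) => Math.exp(l - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// topKFilter: keep only the K most likely tokens and renormalize
// (ties at the threshold are all kept in this simple sketch).
function topKFilter(probs, k) {
  const threshold = [...probs].sort((a, b) => b - a)[k - 1];
  const kept = probs.map((p) => (p >= threshold ? p : 0));
  const sum = kept.reduce((a, b) => a + b, 0);
  return kept.map((p) => p / sum);
}
```

With logits `[2, 1, 0]`, temperature 0.3 pushes almost all probability mass onto the first token, while 1.5 spreads it out; `do_sample: false` corresponds to always taking the argmax instead of sampling.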
## Model Selection

Browse available text generation models on Hugging Face Hub:

**https://huggingface.co/models?pipeline_tag=text-generation&library=transformers.js&sort=trending**

### Selection Tips

- **Small models (< 1B params)**: Fast, browser-friendly, use `dtype: 'q4'`
- **Medium models (1-3B params)**: Balanced quality/speed, use `dtype: 'q4'` or `fp16`
- **Large models (> 3B params)**: High quality, slower, best for Node.js with `dtype: 'fp16'`

Check model cards for:
- Parameter count and model size
- Supported languages
- Benchmark scores
- License restrictions

## Best Practices

1. **Model Size**: Use quantized models (`q4`) for browsers, larger models (`fp16`) for servers
2. **Streaming**: Use streaming for better UX - shows progress and feels responsive
3. **Token Limits**: Set `max_new_tokens` to prevent runaway generation
4. **Temperature**: Tune based on use case (creative: 0.8-1.2, factual: 0.3-0.7)
5. **Memory**: Always call `dispose()` when done
6. **Caching**: Load model once, reuse for multiple requests

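Practices 5 and 6 can be combined in a tiny loader that memoizes the pipeline's load promise so concurrent callers share a single download. This is a sketch: `lazyOnce` and `getGenerator` are hypothetical helper names, not part of the transformers.js API, though the model id mirrors the examples above.

```javascript
// lazyOnce: cache an async factory's promise so the expensive work runs once,
// even if several callers race before the first load resolves.
function lazyOnce(factory) {
  let promise = null;
  const get = () => (promise ??= factory());
  get.reset = () => { promise = null; };  // e.g. call after dispose()
  return get;
}

// Usage sketch with transformers.js:
// import { pipeline } from '@huggingface/transformers';
// const getGenerator = lazyOnce(() =>
//   pipeline('text-generation', 'onnx-community/Qwen2.5-0.5B-Instruct', { dtype: 'q4' })
// );
// const generator = await getGenerator(); // first call loads; later calls reuse it
```

Because the promise (not the resolved value) is cached, two callers that invoke `getGenerator()` in the same tick still trigger only one model load.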
## Related Documentation

- [Pipeline Options](./PIPELINE_OPTIONS.md) - Configure pipeline loading
- [Configuration Reference](./CONFIGURATION.md) - Environment settings
- [Code Examples](./EXAMPLES.md) - More examples for different runtimes
- [Main Skill Guide](../SKILL.md) - Getting started guide

@@ -19,7 +19,7 @@
"skills": "./skills/",
"interface": {
"displayName": "Antigravity Awesome Skills",
"shortDescription": "1,303 plugin-safe skills for coding, security, product, and ops workflows.",
"shortDescription": "1,311 plugin-safe skills for coding, security, product, and ops workflows.",
"longDescription": "Install a plugin-safe Codex distribution of Antigravity Awesome Skills. Skills that still need hardening or target-specific setup remain available in the repo but are excluded from this plugin.",
"developerName": "sickn33 and contributors",
"category": "Productivity",

@@ -1,199 +1,194 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/hf-cli"
name: hugging-face-cli
description: "The hf CLI provides direct terminal access to the Hugging Face Hub for downloading, uploading, and managing repositories, cache, and compute resources."
risk: safe
source: "https://github.com/huggingface/skills/tree/main/skills/hugging-face-cli"
date_added: "2026-02-27"
description: "Use the Hugging Face Hub CLI (`hf`) to download, upload, and manage models, datasets, and Spaces."
risk: unknown
---

# Hugging Face CLI
Install: `curl -LsSf https://hf.co/cli/install.sh | bash -s`.

The `hf` CLI provides direct terminal access to the Hugging Face Hub for downloading, uploading, and managing repositories, cache, and compute resources.
## When to Use

## When to Use This Skill
Use this skill when you need the `hf` CLI for Hub authentication, downloads, uploads, repo management, or basic compute operations.

Use this skill when:
- User needs to download models, datasets, or spaces
- Uploading files to Hub repositories
- Creating Hugging Face repositories
- Managing local cache
- Running compute jobs on HF infrastructure
- Working with Hugging Face Hub authentication
The Hugging Face Hub CLI tool `hf` is available. IMPORTANT: The `hf` command replaces the deprecated `huggingface-cli` command.

## Quick Command Reference
Use `hf --help` to view available commands. Note that auth commands are now all under `hf auth`, e.g. `hf auth whoami`.

| Task | Command |
|------|---------|
| Login | `hf auth login` |
| Download model | `hf download <repo_id>` |
| Download to folder | `hf download <repo_id> --local-dir ./path` |
| Upload folder | `hf upload <repo_id> . .` |
| Create repo | `hf repo create <name>` |
| Create tag | `hf repo tag create <repo_id> <tag>` |
| Delete files | `hf repo-files delete <repo_id> <files>` |
| List cache | `hf cache ls` |
| Remove from cache | `hf cache rm <repo_or_revision>` |
| List models | `hf models ls` |
| Get model info | `hf models info <model_id>` |
| List datasets | `hf datasets ls` |
| Get dataset info | `hf datasets info <dataset_id>` |
| List spaces | `hf spaces ls` |
| Get space info | `hf spaces info <space_id>` |
| List endpoints | `hf endpoints ls` |
| Run GPU job | `hf jobs run --flavor a10g-small <image> <cmd>` |
| Environment info | `hf env` |
Generated with `huggingface_hub v1.8.0`. Run `hf skills add --force` to regenerate.

## Core Commands
## Commands

### Authentication
```bash
hf auth login                    # Interactive login
hf auth login --token $HF_TOKEN  # Non-interactive
hf auth whoami                   # Check current user
hf auth list                     # List stored tokens
hf auth switch                   # Switch between tokens
hf auth logout                   # Log out
```
- `hf download REPO_ID` — Download files from the Hub. `[--type CHOICE --revision TEXT --include TEXT --exclude TEXT --cache-dir TEXT --local-dir TEXT --force-download --dry-run --quiet --max-workers INTEGER]`
- `hf env` — Print information about the environment.
- `hf sync` — Sync files between local directory and a bucket. `[--delete --ignore-times --ignore-sizes --plan TEXT --apply TEXT --dry-run --include TEXT --exclude TEXT --filter-from TEXT --existing --ignore-existing --verbose --quiet]`
- `hf upload REPO_ID` — Upload a file or a folder to the Hub. Recommended for single-commit uploads. `[--type CHOICE --revision TEXT --private --include TEXT --exclude TEXT --delete TEXT --commit-message TEXT --commit-description TEXT --create-pr --every FLOAT --quiet]`
- `hf upload-large-folder REPO_ID LOCAL_PATH` — Upload a large folder to the Hub. Recommended for resumable uploads. `[--type CHOICE --revision TEXT --private --include TEXT --exclude TEXT --num-workers INTEGER --no-report --no-bars]`
- `hf version` — Print information about the hf version.

### Download
```bash
hf download <repo_id>                           # Full repo to cache
hf download <repo_id> file.safetensors          # Specific file
hf download <repo_id> --local-dir ./models      # To local directory
hf download <repo_id> --include "*.safetensors" # Filter by pattern
hf download <repo_id> --repo-type dataset       # Dataset
hf download <repo_id> --revision v1.0           # Specific version
```
### `hf auth` — Manage authentication (login, logout, etc.).

### Upload
```bash
hf upload <repo_id> . .                        # Current dir to root
hf upload <repo_id> ./models /weights          # Folder to path
hf upload <repo_id> model.safetensors          # Single file
hf upload <repo_id> . . --repo-type dataset    # Dataset
hf upload <repo_id> . . --create-pr            # Create PR
hf upload <repo_id> . . --commit-message="msg" # Custom message
```
- `hf auth list` — List all stored access tokens.
- `hf auth login` — Login using a token from huggingface.co/settings/tokens. `[--add-to-git-credential --force]`
- `hf auth logout` — Logout from a specific token. `[--token-name TEXT]`
- `hf auth switch` — Switch between access tokens. `[--token-name TEXT --add-to-git-credential]`
- `hf auth whoami` — Find out which huggingface.co account you are logged in as. `[--format CHOICE]`

### Repository Management
```bash
hf repo create <name>                                      # Create model repo
hf repo create <name> --repo-type dataset                  # Create dataset
hf repo create <name> --private                            # Private repo
hf repo create <name> --repo-type space --space_sdk gradio # Gradio space
hf repo delete <repo_id>                                   # Delete repo
hf repo move <from_id> <to_id>                             # Move repo to new namespace
hf repo settings <repo_id> --private true                  # Update repo settings
hf repo list --repo-type model                             # List repos
hf repo branch create <repo_id> release-v1                 # Create branch
hf repo branch delete <repo_id> release-v1                 # Delete branch
hf repo tag create <repo_id> v1.0                          # Create tag
hf repo tag list <repo_id>                                 # List tags
hf repo tag delete <repo_id> v1.0                          # Delete tag
```
### `hf buckets` — Commands to interact with buckets.

### Delete Files from Repo
```bash
hf repo-files delete <repo_id> folder/  # Delete folder
hf repo-files delete <repo_id> "*.txt"  # Delete with pattern
```
- `hf buckets cp SRC` — Copy a single file to or from a bucket. `[--quiet]`
- `hf buckets create BUCKET_ID` — Create a new bucket. `[--private --exist-ok --quiet]`
- `hf buckets delete BUCKET_ID` — Delete a bucket. `[--yes --missing-ok --quiet]`
- `hf buckets info BUCKET_ID` — Get info about a bucket. `[--quiet]`
- `hf buckets list` — List buckets or files in a bucket. `[--human-readable --tree --recursive --format CHOICE --quiet]`
- `hf buckets move FROM_ID TO_ID` — Move (rename) a bucket to a new name or namespace.
- `hf buckets remove ARGUMENT` — Remove files from a bucket. `[--recursive --yes --dry-run --include TEXT --exclude TEXT --quiet]`
- `hf buckets sync` — Sync files between local directory and a bucket. `[--delete --ignore-times --ignore-sizes --plan TEXT --apply TEXT --dry-run --include TEXT --exclude TEXT --filter-from TEXT --existing --ignore-existing --verbose --quiet]`

### Cache Management
```bash
hf cache ls                 # List cached repos
hf cache ls --revisions     # Include individual revisions
hf cache rm model/gpt2      # Remove cached repo
hf cache rm <revision_hash> # Remove cached revision
hf cache prune              # Remove detached revisions
hf cache verify gpt2        # Verify checksums from cache
```
### `hf cache` — Manage local cache directory.

### Browse Hub
```bash
# Models
hf models ls                                       # List top trending models
hf models ls --search "MiniMax" --author MiniMaxAI # Search models
hf models ls --filter "text-generation" --limit 20 # Filter by task
hf models info MiniMaxAI/MiniMax-M2.1              # Get model info
- `hf cache list` — List cached repositories or revisions. `[--cache-dir TEXT --revisions --filter TEXT --format CHOICE --quiet --sort CHOICE --limit INTEGER]`
- `hf cache prune` — Remove detached revisions from the cache. `[--cache-dir TEXT --yes --dry-run]`
- `hf cache rm TARGETS` — Remove cached repositories or revisions. `[--cache-dir TEXT --yes --dry-run]`
- `hf cache verify REPO_ID` — Verify checksums for a single repo revision from cache or a local directory. `[--type CHOICE --revision TEXT --cache-dir TEXT --local-dir TEXT --fail-on-missing-files --fail-on-extra-files]`

# Datasets
hf datasets ls                                      # List top trending datasets
hf datasets ls --search "finepdfs" --sort downloads # Search datasets
hf datasets info HuggingFaceFW/finepdfs             # Get dataset info
### `hf collections` — Interact with collections on the Hub.

# Spaces
hf spaces ls                          # List top trending spaces
hf spaces ls --filter "3d" --limit 10 # Filter by 3D modeling spaces
hf spaces info enzostvs/deepsite      # Get space info
```
- `hf collections add-item COLLECTION_SLUG ITEM_ID ITEM_TYPE` — Add an item to a collection. `[--note TEXT --exists-ok]`
- `hf collections create TITLE` — Create a new collection on the Hub. `[--namespace TEXT --description TEXT --private --exists-ok]`
- `hf collections delete COLLECTION_SLUG` — Delete a collection from the Hub. `[--missing-ok]`
- `hf collections delete-item COLLECTION_SLUG ITEM_OBJECT_ID` — Delete an item from a collection. `[--missing-ok]`
- `hf collections info COLLECTION_SLUG` — Get info about a collection on the Hub. Output is in JSON format.
- `hf collections list` — List collections on the Hub. `[--owner TEXT --item TEXT --sort CHOICE --limit INTEGER --format CHOICE --quiet]`
- `hf collections update COLLECTION_SLUG` — Update a collection's metadata on the Hub. `[--title TEXT --description TEXT --position INTEGER --private --theme TEXT]`
- `hf collections update-item COLLECTION_SLUG ITEM_OBJECT_ID` — Update an item in a collection. `[--note TEXT --position INTEGER]`

### Jobs (Cloud Compute)
```bash
hf jobs run python:3.12 python script.py      # Run on CPU
hf jobs run --flavor a10g-small <image> <cmd> # Run on GPU
hf jobs run --secrets HF_TOKEN <image> <cmd>  # With HF token
hf jobs ps                                    # List jobs
hf jobs logs <job_id>                         # View logs
hf jobs cancel <job_id>                       # Cancel job
```
### `hf datasets` — Interact with datasets on the Hub.

### Inference Endpoints
```bash
hf endpoints ls                        # List endpoints
hf endpoints deploy my-endpoint \
  --repo openai/gpt-oss-120b \
  --framework vllm \
  --accelerator gpu \
  --instance-size x4 \
  --instance-type nvidia-a10g \
  --region us-east-1 \
  --vendor aws
hf endpoints describe my-endpoint      # Show endpoint details
hf endpoints pause my-endpoint         # Pause endpoint
hf endpoints resume my-endpoint        # Resume endpoint
hf endpoints scale-to-zero my-endpoint # Scale to zero
hf endpoints delete my-endpoint --yes  # Delete endpoint
```
**GPU Flavors:** `cpu-basic`, `cpu-upgrade`, `cpu-xl`, `t4-small`, `t4-medium`, `l4x1`, `l4x4`, `l40sx1`, `l40sx4`, `l40sx8`, `a10g-small`, `a10g-large`, `a10g-largex2`, `a10g-largex4`, `a100-large`, `h100`, `h100x8`
- `hf datasets info DATASET_ID` — Get info about a dataset on the Hub. Output is in JSON format. `[--revision TEXT --expand TEXT]`
- `hf datasets list` — List datasets on the Hub. `[--search TEXT --author TEXT --filter TEXT --sort CHOICE --limit INTEGER --expand TEXT --format CHOICE --quiet]`
- `hf datasets parquet DATASET_ID` — List parquet file URLs available for a dataset. `[--subset TEXT --split TEXT --format CHOICE --quiet]`
- `hf datasets sql SQL` — Execute a raw SQL query with DuckDB against dataset parquet URLs. `[--format CHOICE]`

## Common Patterns
### `hf discussions` — Manage discussions and pull requests on the Hub.

### Download and Use Model Locally
```bash
# Download to local directory for deployment
hf download meta-llama/Llama-3.2-1B-Instruct --local-dir ./model
- `hf discussions close REPO_ID NUM` — Close a discussion or pull request. `[--comment TEXT --yes --type CHOICE]`
- `hf discussions comment REPO_ID NUM` — Comment on a discussion or pull request. `[--body TEXT --body-file PATH --type CHOICE]`
- `hf discussions create REPO_ID --title TEXT` — Create a new discussion or pull request on a repo. `[--body TEXT --body-file PATH --pull-request --type CHOICE]`
- `hf discussions diff REPO_ID NUM` — Show the diff of a pull request. `[--type CHOICE]`
- `hf discussions info REPO_ID NUM` — Get info about a discussion or pull request. `[--comments --diff --no-color --type CHOICE --format CHOICE]`
- `hf discussions list REPO_ID` — List discussions and pull requests on a repo. `[--status CHOICE --kind CHOICE --author TEXT --limit INTEGER --type CHOICE --format CHOICE --quiet]`
- `hf discussions merge REPO_ID NUM` — Merge a pull request. `[--comment TEXT --yes --type CHOICE]`
- `hf discussions rename REPO_ID NUM NEW_TITLE` — Rename a discussion or pull request. `[--type CHOICE]`
- `hf discussions reopen REPO_ID NUM` — Reopen a closed discussion or pull request. `[--comment TEXT --yes --type CHOICE]`

# Or use cache and get path
MODEL_PATH=$(hf download meta-llama/Llama-3.2-1B-Instruct --quiet)
```
### `hf endpoints` — Manage Hugging Face Inference Endpoints.

### Publish Model/Dataset
```bash
hf repo create my-username/my-model --private
hf upload my-username/my-model ./output . --commit-message="Initial release"
hf repo tag create my-username/my-model v1.0
```

- `hf endpoints catalog deploy --repo TEXT` — Deploy an Inference Endpoint from the Model Catalog. `[--name TEXT --accelerator TEXT --namespace TEXT]`
- `hf endpoints catalog list` — List available Catalog models.
- `hf endpoints delete NAME` — Delete an Inference Endpoint permanently. `[--namespace TEXT --yes]`
- `hf endpoints deploy NAME --repo TEXT --framework TEXT --accelerator TEXT --instance-size TEXT --instance-type TEXT --region TEXT --vendor TEXT` — Deploy an Inference Endpoint from a Hub repository. `[--namespace TEXT --task TEXT --min-replica INTEGER --max-replica INTEGER --scale-to-zero-timeout INTEGER --scaling-metric CHOICE --scaling-threshold FLOAT]`
- `hf endpoints describe NAME` — Get information about an existing endpoint. `[--namespace TEXT]`
- `hf endpoints list` — Lists all Inference Endpoints for the given namespace. `[--namespace TEXT --format CHOICE --quiet]`
- `hf endpoints pause NAME` — Pause an Inference Endpoint. `[--namespace TEXT]`
- `hf endpoints resume NAME` — Resume an Inference Endpoint. `[--namespace TEXT --fail-if-already-running]`
- `hf endpoints scale-to-zero NAME` — Scale an Inference Endpoint to zero. `[--namespace TEXT]`
- `hf endpoints update NAME` — Update an existing endpoint. `[--namespace TEXT --repo TEXT --accelerator TEXT --instance-size TEXT --instance-type TEXT --framework TEXT --revision TEXT --task TEXT --min-replica INTEGER --max-replica INTEGER --scale-to-zero-timeout INTEGER --scaling-metric CHOICE --scaling-threshold FLOAT]`

### Sync Space with Local
```bash
hf upload my-username/my-space . . --repo-type space \
  --exclude="logs/*" --delete="*" --commit-message="Sync"
```
### `hf extensions` — Manage hf CLI extensions.

### Check Cache Usage
```bash
hf cache ls            # See all cached repos and sizes
hf cache rm model/gpt2 # Remove a repo from cache
```
- `hf extensions exec NAME` — Execute an installed extension.
- `hf extensions install REPO_ID` — Install an extension from a public GitHub repository. `[--force]`
- `hf extensions list` — List installed extension commands. `[--format CHOICE --quiet]`
- `hf extensions remove NAME` — Remove an installed extension.
- `hf extensions search` — Search extensions available on GitHub (tagged with 'hf-extension' topic). `[--format CHOICE --quiet]`

## Key Options
### `hf jobs` — Run and manage Jobs on the Hub.

- `--repo-type`: `model` (default), `dataset`, `space`
- `--revision`: Branch, tag, or commit hash
- `--token`: Override authentication
- `--quiet`: Output only essential info (paths/URLs)
- `hf jobs cancel JOB_ID` — Cancel a Job `[--namespace TEXT]`
- `hf jobs hardware` — List available hardware options for Jobs
- `hf jobs inspect JOB_IDS` — Display detailed information on one or more Jobs `[--namespace TEXT]`
- `hf jobs logs JOB_ID` — Fetch the logs of a Job. `[--follow --tail INTEGER --namespace TEXT]`
- `hf jobs ps` — List Jobs. `[--all --namespace TEXT --filter TEXT --format TEXT --quiet]`
- `hf jobs run IMAGE COMMAND` — Run a Job. `[--env TEXT --secrets TEXT --label TEXT --volume TEXT --env-file TEXT --secrets-file TEXT --flavor CHOICE --timeout TEXT --detach --namespace TEXT]`
- `hf jobs scheduled delete SCHEDULED_JOB_ID` — Delete a scheduled Job. `[--namespace TEXT]`
- `hf jobs scheduled inspect SCHEDULED_JOB_IDS` — Display detailed information on one or more scheduled Jobs `[--namespace TEXT]`
- `hf jobs scheduled ps` — List scheduled Jobs `[--all --namespace TEXT --filter TEXT --format TEXT --quiet]`
- `hf jobs scheduled resume SCHEDULED_JOB_ID` — Resume (unpause) a scheduled Job. `[--namespace TEXT]`
- `hf jobs scheduled run SCHEDULE IMAGE COMMAND` — Schedule a Job. `[--suspend --concurrency --env TEXT --secrets TEXT --label TEXT --volume TEXT --env-file TEXT --secrets-file TEXT --flavor CHOICE --timeout TEXT --namespace TEXT]`
- `hf jobs scheduled suspend SCHEDULED_JOB_ID` — Suspend (pause) a scheduled Job. `[--namespace TEXT]`
- `hf jobs scheduled uv run SCHEDULE SCRIPT` — Run a UV script (local file or URL) on HF infrastructure `[--suspend --concurrency --image TEXT --flavor CHOICE --env TEXT --secrets TEXT --label TEXT --volume TEXT --env-file TEXT --secrets-file TEXT --timeout TEXT --namespace TEXT --with TEXT --python TEXT]`
- `hf jobs stats` — Fetch the resource usage statistics and metrics of Jobs `[--namespace TEXT]`
- `hf jobs uv run SCRIPT` — Run a UV script (local file or URL) on HF infrastructure `[--image TEXT --flavor CHOICE --env TEXT --secrets TEXT --label TEXT --volume TEXT --env-file TEXT --secrets-file TEXT --timeout TEXT --detach --namespace TEXT --with TEXT --python TEXT]`

## References
### `hf models` — Interact with models on the Hub.

- **Complete command reference**: See references/commands.md
- **Workflow examples**: See references/examples.md
- `hf models info MODEL_ID` — Get info about a model on the Hub. Output is in JSON format. `[--revision TEXT --expand TEXT]`
- `hf models list` — List models on the Hub. `[--search TEXT --author TEXT --filter TEXT --num-parameters TEXT --sort CHOICE --limit INTEGER --expand TEXT --format CHOICE --quiet]`

### `hf papers` — Interact with papers on the Hub.

- `hf papers info PAPER_ID` — Get info about a paper on the Hub. Output is in JSON format.
- `hf papers list` — List daily papers on the Hub. `[--date TEXT --week TEXT --month TEXT --submitter TEXT --sort CHOICE --limit INTEGER --format CHOICE --quiet]`
- `hf papers read PAPER_ID` — Read a paper as markdown.
- `hf papers search QUERY` — Search papers on the Hub. `[--limit INTEGER --format CHOICE --quiet]`

### `hf repos` — Manage repos on the Hub.

- `hf repos branch create REPO_ID BRANCH` — Create a new branch for a repo on the Hub. `[--revision TEXT --type CHOICE --exist-ok]`
- `hf repos branch delete REPO_ID BRANCH` — Delete a branch from a repo on the Hub. `[--type CHOICE]`
- `hf repos create REPO_ID` — Create a new repo on the Hub. `[--type CHOICE --space-sdk TEXT --private --public --protected --exist-ok --resource-group-id TEXT --flavor TEXT --storage TEXT --sleep-time INTEGER --secrets TEXT --secrets-file TEXT --env TEXT --env-file TEXT]`
- `hf repos delete REPO_ID` — Delete a repo from the Hub. This is an irreversible operation. `[--type CHOICE --missing-ok]`
- `hf repos delete-files REPO_ID PATTERNS` — Delete files from a repo on the Hub. `[--type CHOICE --revision TEXT --commit-message TEXT --commit-description TEXT --create-pr]`
- `hf repos duplicate FROM_ID` — Duplicate a repo on the Hub (model, dataset, or Space). `[--type CHOICE --private --public --protected --exist-ok --flavor TEXT --storage TEXT --sleep-time INTEGER --secrets TEXT --secrets-file TEXT --env TEXT --env-file TEXT]`
- `hf repos move FROM_ID TO_ID` — Move a repository from a namespace to another namespace. `[--type CHOICE]`
- `hf repos settings REPO_ID` — Update the settings of a repository. `[--gated CHOICE --private --public --protected --type CHOICE]`
- `hf repos tag create REPO_ID TAG` — Create a tag for a repo. `[--message TEXT --revision TEXT --type CHOICE]`
- `hf repos tag delete REPO_ID TAG` — Delete a tag for a repo. `[--yes --type CHOICE]`
- `hf repos tag list REPO_ID` — List tags for a repo. `[--type CHOICE]`

### `hf skills` — Manage skills for AI assistants.

- `hf skills add` — Download a skill and install it for an AI assistant. `[--claude --codex --cursor --opencode --global --dest PATH --force]`
- `hf skills preview` — Print the generated SKILL.md to stdout.

### `hf spaces` — Interact with spaces on the Hub.

- `hf spaces dev-mode SPACE_ID` — Enable or disable dev mode on a Space. `[--stop]`
- `hf spaces hot-reload SPACE_ID` — Hot-reload any Python file of a Space without a full rebuild + restart. `[--local-file TEXT --skip-checks --skip-summary]`
- `hf spaces info SPACE_ID` — Get info about a space on the Hub. Output is in JSON format. `[--revision TEXT --expand TEXT]`
- `hf spaces list` — List spaces on the Hub. `[--search TEXT --author TEXT --filter TEXT --sort CHOICE --limit INTEGER --expand TEXT --format CHOICE --quiet]`

### `hf webhooks` — Manage webhooks on the Hub.

- `hf webhooks create --watch TEXT` — Create a new webhook. `[--url TEXT --job-id TEXT --domain CHOICE --secret TEXT]`
- `hf webhooks delete WEBHOOK_ID` — Delete a webhook permanently. `[--yes]`
- `hf webhooks disable WEBHOOK_ID` — Disable an active webhook.
- `hf webhooks enable WEBHOOK_ID` — Enable a disabled webhook.
- `hf webhooks info WEBHOOK_ID` — Show full details for a single webhook as JSON.
- `hf webhooks list` — List all webhooks for the current user. `[--format CHOICE --quiet]`
- `hf webhooks update WEBHOOK_ID` — Update an existing webhook. Only provided options are changed. `[--url TEXT --watch TEXT --domain CHOICE --secret TEXT]`

## Common options

- `--format` — Output format: `--format json` (or `--json`) or `--format table` (default).
- `-q / --quiet` — Minimal output.
- `--revision` — Git revision id which can be a branch name, a tag, or a commit hash.
- `--token` — Use a User Access Token. Prefer setting `HF_TOKEN` env var instead of passing `--token`.
- `--type` — The type of repository (model, dataset, or space).

## Mounting repos as local filesystems

To mount Hub repositories or buckets as local filesystems — no download, no copy, no waiting — use `hf-mount`. Files are fetched on demand. GitHub: https://github.com/huggingface/hf-mount

Install: `curl -fsSL https://raw.githubusercontent.com/huggingface/hf-mount/main/install.sh | sh`

Some command examples:
- `hf-mount start repo openai-community/gpt2 /tmp/gpt2` — mount a repo (read-only)
- `hf-mount start --hf-token $HF_TOKEN bucket myuser/my-bucket /tmp/data` — mount a bucket (read-write)
- `hf-mount status` / `hf-mount stop /tmp/data` — list or unmount

## Tips

- Use `hf <command> --help` for full options, descriptions, usage, and real-world examples
- Authenticate with `HF_TOKEN` env var (recommended) or with `--token`

@@ -0,0 +1,213 @@
---
source: "https://github.com/huggingface/skills/tree/main/skills/huggingface-community-evals"
name: hugging-face-community-evals
description: Run local evaluations for Hugging Face Hub models with inspect-ai or lighteval.
risk: unknown
---

# Overview

## When to Use

Use this skill for local model evaluation, backend selection, and GPU smoke tests outside the Hugging Face Jobs workflow.

This skill is for **running evaluations against models on the Hugging Face Hub on local hardware**.

It covers:
- `inspect-ai` with local inference
- `lighteval` with local inference
- choosing between `vllm`, Hugging Face Transformers, and `accelerate`
- smoke tests, task selection, and backend fallback strategy

It does **not** cover:
- Hugging Face Jobs orchestration
- model-card or `model-index` edits
- README table extraction
- Artificial Analysis imports
- `.eval_results` generation or publishing
- PR creation or community-evals automation

If the user wants to **run the same eval remotely on Hugging Face Jobs**, hand off to the `hugging-face-jobs` skill and pass it one of the local scripts in this skill.

If the user wants to **publish results into the community evals workflow**, stop after generating the evaluation run and hand off that publishing step to `~/code/community-evals`.

> All paths below are relative to the directory containing this `SKILL.md`.

# When To Use Which Script

| Use case | Script |
|---|---|
| Local `inspect-ai` eval on a Hub model via inference providers | `scripts/inspect_eval_uv.py` |
| Local GPU eval with `inspect-ai` using `vllm` or Transformers | `scripts/inspect_vllm_uv.py` |
| Local GPU eval with `lighteval` using `vllm` or `accelerate` | `scripts/lighteval_vllm_uv.py` |
| Extra command patterns | `examples/USAGE_EXAMPLES.md` |

# Prerequisites

- Prefer `uv run` for local execution.
- Set `HF_TOKEN` for gated/private models.
- For local GPU runs, verify GPU access before starting:

```bash
uv --version
printenv HF_TOKEN >/dev/null
nvidia-smi
```

If `nvidia-smi` is unavailable, either:
- use `scripts/inspect_eval_uv.py` for lighter provider-backed evaluation, or
- hand off to the `hugging-face-jobs` skill if the user wants remote compute.

# Core Workflow

1. Choose the evaluation framework.
   - Use `inspect-ai` when you want explicit task control and inspect-native flows.
   - Use `lighteval` when the benchmark is naturally expressed as a lighteval task string, especially leaderboard-style tasks.
2. Choose the inference backend.
   - Prefer `vllm` for throughput on supported architectures.
   - Use Hugging Face Transformers (`--backend hf`) or `accelerate` as compatibility fallbacks.
3. Start with a smoke test.
   - `inspect-ai`: add `--limit 10` or similar.
   - `lighteval`: add `--max-samples 10`.
4. Scale up only after the smoke test passes.
5. If the user wants remote execution, hand off to `hugging-face-jobs` with the same script and args.

# Quick Start

## Option A: inspect-ai via Inference Providers

Best when the model is already supported by Hugging Face Inference Providers and you want the lowest local setup overhead.

```bash
uv run scripts/inspect_eval_uv.py \
  --model meta-llama/Llama-3.2-1B \
  --task mmlu \
  --limit 20
```

Use this path when:
- you want a quick local smoke test
- you do not need direct GPU control
- the task already exists in `inspect-evals`

## Option B: inspect-ai on Local GPU

Best when you need to load the Hub model directly, use `vllm`, or fall back to Transformers for unsupported architectures.

Local GPU:

```bash
uv run scripts/inspect_vllm_uv.py \
  --model meta-llama/Llama-3.2-1B \
  --task gsm8k \
  --limit 20
```

Transformers fallback:

```bash
uv run scripts/inspect_vllm_uv.py \
  --model microsoft/phi-2 \
  --task mmlu \
  --backend hf \
  --trust-remote-code \
  --limit 20
```

## Option C: lighteval on Local GPU

Best when the task is naturally expressed as a `lighteval` task string, especially Open LLM Leaderboard style benchmarks.

Local GPU:

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model meta-llama/Llama-3.2-3B-Instruct \
  --tasks "leaderboard|mmlu|5,leaderboard|gsm8k|5" \
  --max-samples 20 \
  --use-chat-template
```

`accelerate` fallback:

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model microsoft/phi-2 \
  --tasks "leaderboard|mmlu|5" \
  --backend accelerate \
  --trust-remote-code \
  --max-samples 20
```

# Remote Execution Boundary

This skill intentionally stops at **local execution and backend selection**.

If the user wants to:
- run these scripts on Hugging Face Jobs
- pick remote hardware
- pass secrets to remote jobs
- schedule recurring runs
- inspect / cancel / monitor jobs

then switch to the **`hugging-face-jobs`** skill and pass it one of these scripts plus the chosen arguments.

# Task Selection

`inspect-ai` examples:
- `mmlu`
- `gsm8k`
- `hellaswag`
- `arc_challenge`
- `truthfulqa`
- `winogrande`
- `humaneval`

`lighteval` task strings use `suite|task|num_fewshot`:
- `leaderboard|mmlu|5`
- `leaderboard|gsm8k|5`
- `leaderboard|arc_challenge|25`
- `lighteval|hellaswag|0`

Multiple `lighteval` tasks can be comma-separated in `--tasks`.

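The `suite|task|num_fewshot` format decomposes mechanically. A minimal sketch of a parser for such strings (a hypothetical helper for illustration, not part of `lighteval` itself):

```python
from typing import NamedTuple


class LightevalTask(NamedTuple):
    suite: str
    task: str
    num_fewshot: int


def parse_tasks(spec: str) -> list[LightevalTask]:
    """Split a comma-separated `suite|task|num_fewshot` spec into its parts."""
    out = []
    for item in spec.split(","):
        suite, task, shots = item.strip().split("|")
        out.append(LightevalTask(suite, task, int(shots)))
    return out
```

For example, `parse_tasks("leaderboard|mmlu|5,leaderboard|gsm8k|5")` yields two tasks from the `leaderboard` suite, each with 5 few-shot examples.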
# Backend Selection

- Prefer `inspect_vllm_uv.py --backend vllm` for fast GPU inference on supported architectures.
- Use `inspect_vllm_uv.py --backend hf` when `vllm` does not support the model.
- Prefer `lighteval_vllm_uv.py --backend vllm` for throughput on supported models.
- Use `lighteval_vllm_uv.py --backend accelerate` as the compatibility fallback.
- Use `inspect_eval_uv.py` when Inference Providers already cover the model and you do not need direct GPU control.

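The fallback rules above amount to a two-way switch per framework. A hypothetical sketch of that decision (the function and its signature are illustrative, not part of either tool):

```python
def pick_backend(framework: str, vllm_supported: bool) -> str:
    """Map the selection rules above to a --backend value for each framework."""
    if vllm_supported:
        return "vllm"          # preferred for throughput on supported architectures
    if framework == "inspect-ai":
        return "hf"            # Transformers fallback
    if framework == "lighteval":
        return "accelerate"    # compatibility fallback
    raise ValueError(f"unknown framework: {framework}")
```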
# Hardware Guidance

| Model size | Suggested local hardware |
|---|---|
| `< 3B` | consumer GPU / Apple Silicon / small dev GPU |
| `3B - 13B` | stronger local GPU |
| `13B+` | high-memory local GPU or hand off to `hugging-face-jobs` |

For smoke tests, prefer cheaper local runs plus `--limit` or `--max-samples`.

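A rough rule of thumb behind the table: fp16/bf16 weights take about 2 bytes per parameter, plus KV-cache and activation overhead on top. A back-of-the-envelope sketch (the 20% overhead factor is an assumption, not a measured value):

```python
def min_vram_gb(params_billion: float, bytes_per_param: int = 2, overhead: float = 1.2) -> float:
    """Rough minimum VRAM for inference: weight bytes times an assumed
    overhead factor for KV cache and activations."""
    return params_billion * bytes_per_param * overhead

# Roughly 7 GB for a 3B model in bf16 and about 31 GB for 13B,
# consistent with the consumer-GPU vs. high-memory cutoffs in the table.
```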
# Troubleshooting

- CUDA or vLLM OOM:
  - reduce `--batch-size`
  - reduce `--gpu-memory-utilization`
  - switch to a smaller model for the smoke test
  - if necessary, hand off to `hugging-face-jobs`
- Model unsupported by `vllm`:
  - switch to `--backend hf` for `inspect-ai`
  - switch to `--backend accelerate` for `lighteval`
- Gated/private repo access fails:
  - verify `HF_TOKEN`
- Custom model code required:
  - add `--trust-remote-code`

# Examples

See:
- `examples/USAGE_EXAMPLES.md` for local command patterns
- `scripts/inspect_eval_uv.py`
- `scripts/inspect_vllm_uv.py`
- `scripts/lighteval_vllm_uv.py`
@@ -0,0 +1,3 @@
# Hugging Face Token (required for gated/private models)
# Get your token at: https://huggingface.co/settings/tokens
HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxx
@@ -0,0 +1,101 @@
# Usage Examples

This document provides practical examples for **running evaluations locally** against Hugging Face Hub models.

## What this skill covers

- `inspect-ai` local runs
- `inspect-ai` with `vllm` or Transformers backends
- `lighteval` local runs with `vllm` or `accelerate`
- smoke tests and backend fallback patterns

## What this skill does NOT cover

- `model-index`
- `.eval_results`
- community eval publication workflows
- model-card PR creation
- Hugging Face Jobs orchestration

If you want to run these same scripts remotely, use the `hugging-face-jobs` skill and pass one of the scripts in `scripts/`.

## Setup

```bash
cd skills/hugging-face-evaluation
export HF_TOKEN=hf_xxx
uv --version
```

For local GPU runs:

```bash
nvidia-smi
```

## inspect-ai examples

### Quick smoke test

```bash
uv run scripts/inspect_eval_uv.py \
  --model meta-llama/Llama-3.2-1B \
  --task mmlu \
  --limit 10
```

### Local GPU with vLLM

```bash
uv run scripts/inspect_vllm_uv.py \
  --model meta-llama/Llama-3.1-8B-Instruct \
  --task gsm8k \
  --limit 20
```

### Transformers fallback

```bash
uv run scripts/inspect_vllm_uv.py \
  --model microsoft/phi-2 \
  --task mmlu \
  --backend hf \
  --trust-remote-code \
  --limit 20
```

## lighteval examples

### Single task

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model meta-llama/Llama-3.2-3B-Instruct \
  --tasks "leaderboard|mmlu|5" \
  --max-samples 20
```

### Multiple tasks

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model meta-llama/Llama-3.2-3B-Instruct \
  --tasks "leaderboard|mmlu|5,leaderboard|gsm8k|5" \
  --max-samples 20 \
  --use-chat-template
```

### accelerate fallback

```bash
uv run scripts/lighteval_vllm_uv.py \
  --model microsoft/phi-2 \
  --tasks "leaderboard|mmlu|5" \
  --backend accelerate \
  --trust-remote-code \
  --max-samples 20
```

## Hand-off to Hugging Face Jobs

When local hardware is not enough, switch to the `hugging-face-jobs` skill and run one of these scripts remotely. Keep the script path and args; move the orchestration there.
@@ -0,0 +1,104 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "inspect-ai>=0.3.0",
#     "inspect-evals",
#     "openai",
# ]
# ///

"""
Entry point script for running inspect-ai evaluations against Hugging Face inference providers.
"""

from __future__ import annotations

import argparse
import os
import subprocess
import sys
from pathlib import Path
from typing import Optional


def _inspect_evals_tasks_root() -> Optional[Path]:
    """Return the installed inspect_evals package path if available."""
    try:
        import inspect_evals

        return Path(inspect_evals.__file__).parent
    except Exception:
        return None


def _normalize_task(task: str) -> str:
    """Allow lighteval-style `suite|task|shots` strings by keeping the task name."""
    if "|" in task:
        parts = task.split("|")
        if len(parts) >= 2 and parts[1]:
            return parts[1]
    return task


def main() -> None:
    parser = argparse.ArgumentParser(description="Inspect-ai job runner")
    parser.add_argument("--model", required=True, help="Model ID on Hugging Face Hub")
    parser.add_argument("--task", required=True, help="inspect-ai task to execute")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of samples to evaluate")
    parser.add_argument(
        "--tasks-root",
        default=None,
        help="Optional path to inspect task files. Defaults to the installed inspect_evals package.",
    )
    parser.add_argument(
        "--sandbox",
        default="local",
        help="Sandbox backend to use (default: local for HF jobs without Docker).",
    )
    args = parser.parse_args()

    # Ensure downstream libraries can read the token passed as a secret
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", hf_token)
        os.environ.setdefault("HF_HUB_TOKEN", hf_token)

    task = _normalize_task(args.task)
    tasks_root = Path(args.tasks_root) if args.tasks_root else _inspect_evals_tasks_root()
    if tasks_root and not tasks_root.exists():
        tasks_root = None

    cmd = [
        "inspect",
        "eval",
        task,
        "--model",
        f"hf-inference-providers/{args.model}",
        "--log-level",
        "info",
        # Limit concurrent requests to the provider to avoid OOM and rate-limit errors
        "--max-connections",
        "1",
        # Set a small positive temperature (HF doesn't allow temperature=0)
        "--temperature",
        "0.001",
    ]

    if args.sandbox:
        cmd.extend(["--sandbox", args.sandbox])

    if args.limit:
        cmd.extend(["--limit", str(args.limit)])

    try:
        subprocess.run(cmd, check=True, cwd=tasks_root)
        print("Evaluation complete.")
    except subprocess.CalledProcessError as exc:
        location = f" (cwd={tasks_root})" if tasks_root else ""
        print(f"Evaluation failed with exit code {exc.returncode}{location}", file=sys.stderr)
        raise


if __name__ == "__main__":
    main()