diff --git a/.changeset/thin-peaches-play.md b/.changeset/thin-peaches-play.md new file mode 100644 index 0000000..cd1332c --- /dev/null +++ b/.changeset/thin-peaches-play.md @@ -0,0 +1,5 @@ +--- +"shadcn-zod-form": patch +--- + +feat: add array fields support diff --git a/src/commands/generate.ts b/src/commands/generate.ts index e6a7690..14eb218 100644 --- a/src/commands/generate.ts +++ b/src/commands/generate.ts @@ -8,11 +8,11 @@ import template from "lodash.template"; import ora from "ora"; import prompts from "prompts"; import { z } from "zod"; -import { discoverZodSchemas } from "../utils/discover-zod"; import { getFormFields } from "../utils/form-fields"; import { getConfig } from "../utils/get-config"; import { handleError } from "../utils/handle-error"; import { logger } from "../utils/logger"; +import { parseZodSchemasFromFile } from "../utils/parse-zod"; import { formTemplate } from "../utils/templates/form"; import { transform } from "../utils/transformers"; @@ -39,7 +39,7 @@ export const generate = new Command() process.exit(1); } - const zodSchemas = discoverZodSchemas(config, options.schema); + const zodSchemas = parseZodSchemasFromFile(config, options.schema); if (Object.keys(zodSchemas).length === 0) { logger.error("No Zod schemas found in the specified file."); @@ -103,7 +103,7 @@ export const generate = new Command() process.exit(1); } - const { components, imports } = getFormFields( + const { components, imports, functions } = getFormFields( zodSchemas[selectedSchema].schema, ); @@ -112,9 +112,9 @@ export const generate = new Command() schema: selectedSchema, formName: camelCase(name).charAt(0).toUpperCase() + camelCase(name).slice(1), - defaultValues: "{}", + functions, components, - schemaImport: zodSchemas[selectedSchema].importStr, + schemaImport: zodSchemas[selectedSchema].import, imports, }), filename: `${name}.tsx`, diff --git a/src/utils/discover-zod.ts b/src/utils/discover-zod.ts deleted file mode 100644 index 98b6cef..0000000 --- a/src/utils/discover-zod.ts +++ /dev/null @@ -1,134 +0,0 @@ -import path from "node:path"; -import { type Expression, type Node, Project, SyntaxKind } from "ts-morph"; -import type { Config } from "./get-config"; -import { logger } from "./logger"; - -export type ParsedSchema = Record; - -export type ParsedSchemaValue = { - type: string; - children?: Record; - options?: string[]; -}; - -interface SchemaInfo { - importStr: string; - schema: Record; -} - -const UNSUPPORTED_TYPES = ["z.array", "z.record"]; - -export function discoverZodSchemas( - config: Config, - filePath: string, -): Record { - const project = new Project({ - tsConfigFilePath: path.resolve(process.cwd(), "tsconfig.json"), - }); - const sourceFile = project.addSourceFileAtPath(filePath); - - const schemas: Record = {}; - - const variableDeclarations = sourceFile.getDescendantsOfKind( - SyntaxKind.VariableDeclaration, - ); - - for (const declaration of variableDeclarations) { - const initializer = declaration.getInitializer(); - - if (initializer && isZodObjectSchema(initializer)) { - const schemaName = declaration.getName(); - - // todo: apply current ts config paths if needed - let relativePath = path.relative(config.resolvedPaths.forms, filePath); - relativePath = relativePath.replace(/\.(ts|tsx|js|jsx)$/, ""); - relativePath = relativePath.replace(/\\/g, "/"); - if (!path.isAbsolute(relativePath) && !relativePath.startsWith(".")) { - relativePath = `./${relativePath}`; - } - - const isDefaultExport = - sourceFile - .getExportedDeclarations() - .get("default") - ?.some((d) => d === declaration) || false; - - schemas[schemaName] = { - schema: extractObjectSchema(initializer), - importStr: isDefaultExport - ? `import ${schemaName} from "${relativePath}";` - : `import { ${schemaName} } from "${relativePath}";`, - }; - } - } - - return schemas; -} - -function isZodObjectSchema(node: Node): boolean { - if (node.getKind() === SyntaxKind.CallExpression) { - const callExpression = node.asKind(SyntaxKind.CallExpression); - const expression = callExpression?.getExpression(); - - if (expression?.getKind() === SyntaxKind.PropertyAccessExpression) { - const propertyAccess = expression.asKind( - SyntaxKind.PropertyAccessExpression, - ); - const objectName = propertyAccess?.getExpression().getText(); - const propertyName = propertyAccess?.getName(); - - return objectName === "z" && propertyName === "object"; - } - } - - return false; -} - -function extractObjectSchema(node: Node): ParsedSchema { - const schemaObject: ParsedSchema = {}; - - const objectLiteral = node.getDescendantsOfKind( - SyntaxKind.ObjectLiteralExpression, - )[0]; - - for (const property of objectLiteral.getChildrenOfKind( - SyntaxKind.PropertyAssignment, - )) { - const propertyName = property.getName(); - const propertyValue = property.getInitializer(); - - if (propertyValue) { - schemaObject[propertyName] = extractZodType(propertyValue); - } - } - - return schemaObject; -} - -function extractZodType(node: Expression): ParsedSchemaValue { - const text = node.getText(); - - const unsupported = UNSUPPORTED_TYPES.find((type) => text.startsWith(type)); - if (unsupported) { - logger.warn(`${unsupported} type is not currently supported.`); - return { type: "unsupported" }; - } - - if (text.startsWith("z.object(")) - return { type: "object", children: extractObjectSchema(node) }; - - if (text.includes(".datetime(")) return { type: "datetime" }; - if (text.includes(".date(")) return { type: "date" }; - - if (text.startsWith("z.enum(")) { - const enumValues = node - .getFirstChildByKind(SyntaxKind.ArrayLiteralExpression) - ?.getElements() - .map((element) => element.getText().replace(/['"]/g, "")); - return { type: "enum", options: enumValues || [] }; - } - - const type = text.split(".")[1]?.split("(")[0]; - - return { type: type ?? "unknown" }; -} diff --git a/src/utils/form-fields.ts b/src/utils/form-fields.ts index e24f60a..e61358c 100644 --- a/src/utils/form-fields.ts +++ b/src/utils/form-fields.ts @@ -1,81 +1,194 @@ import camelCase from "lodash.camelcase"; import startCase from "lodash.startcase"; import template from "lodash.template"; -import type { ParsedSchema, ParsedSchemaValue } from "./discover-zod"; +import { z } from "zod"; +import { logger } from "./logger"; +import { arrayFieldTemplate } from "./templates/array-field"; import { formFieldTemplate } from "./templates/form-field"; import { inputs, optionItem } from "./templates/inputs"; -export function getFormFields(schema: ParsedSchema): { +type FormFieldsResult = { imports: string; components: string; -} { - const flattenedSchema = flattenSchema(schema); + functions: string; +}; + +export function getFormFields(schema: z.ZodTypeAny): FormFieldsResult { const components: string[] = []; + const functions: string[] = []; const imports: Set = new Set(); - for (const [key, value] of Object.entries(flattenedSchema)) { - const { component, import: importStatement } = getInputComponent(value); - const formField = template(formFieldTemplate)({ - name: key, - label: getFieldLabel(key), - component, - }); - - components.push(formField); - imports.add(importStatement); - } + processSchema(schema, "", components, imports, functions); return { - imports: Array.from(imports).join("\n"), - components: components.join("\n"), + imports: Array.from(imports) + .filter((importStatement) => importStatement) + .join("\n"), + components: components.join(""), + functions: functions.join(""), }; } -function flattenSchema( - schema: ParsedSchema | ParsedSchemaValue, +function processSchema( + schema: z.ZodTypeAny, prefix = "", -): ParsedSchema { - const flattened: ParsedSchema = {}; + components: string[] = [], + imports: Set = new Set(), + functions: string[] = [], +): FormFieldsResult { + if (schema instanceof z.ZodNullable || schema instanceof z.ZodOptional) { + return processSchema( + schema.unwrap(), + prefix, + components, + imports, + functions, + ); + } - for (const [key, value] of Object.entries(schema)) { - const newKey = prefix ? `${prefix}.${key}` : key; + if (schema instanceof z.ZodObject) { + return processObjectSchema(schema, prefix, components, imports, functions); + } - if (value.type === "object") { - Object.assign(flattened, flattenSchema(value.children, newKey)); - } else { - flattened[newKey] = value; - } + if (schema instanceof z.ZodArray) { + return processArraySchema(schema, prefix, components, imports, functions); } - return flattened; + // Process primitive types + const { component, import: importStatement } = getInputComponent( + schema, + prefix, + ); + components.push(component); + imports.add(importStatement); + + return { + imports: Array.from(imports).join(""), + components: components.join(""), + functions: functions.join(""), + }; } -function getFieldLabel(key: string): string { - const parts = key.includes(".") ? key.split(".") : [key]; +function processObjectSchema( + schema: z.ZodObject, + prefix: string, + components: string[], + imports: Set, + functions: string[], +): FormFieldsResult { + for (const [key, value] of Object.entries(schema.shape)) { + const newKey = prefix ? `${prefix}.${key}` : key; + processSchema( + value as z.ZodTypeAny, + newKey, + components, + imports, + functions, + ); + } - return parts.map((part) => startCase(camelCase(part))).join(" "); + return { + imports: Array.from(imports).join(""), + components: components.join(""), + functions: functions.join(""), + }; } -function getInputComponent(field: ParsedSchemaValue): { +function processArraySchema( + schema: z.ZodArray, + prefix: string, + components: string[], + imports: Set, + functions: string[], +): FormFieldsResult { + if (schema.element instanceof z.ZodObject) { + const { components: children } = processSchema( + schema.element, + `${prefix}.\${index}`, + ); + + const defaultValues = getObjectDefaultValue(schema.element); + const arrayFieldComponent = template(arrayFieldTemplate.component)({ + children, + defaultValues: JSON.stringify(defaultValues).replace( + /"([^"]+)":/g, + "$1:", + ), + }); + + const arrayFieldFunctions = template(arrayFieldTemplate.functions)({ + name: prefix, + }); + + components.push(arrayFieldComponent); + imports.add(arrayFieldTemplate.import); + functions.push(arrayFieldFunctions); + } else { + logger.warn(`Only objects are supported in arrays, skipping ${prefix}`); + } + + return { + imports: Array.from(imports).join(""), + components: components.join(""), + functions: functions.join(""), + }; +} + +function getInputComponent( + field: z.ZodTypeAny, + prefix: string, +): { component: string; import: string; } { - const input = inputs[field.type]; + const input = inputs[field.constructor.name]; + const inputProps = { + children: "", + }; if (!input) { + logger.warn(`Unsupported field type: ${field.constructor.name}`); return { component: "", import: "", }; - // throw new Error(`No input component found for type: ${type}`); } + if (field instanceof z.ZodEnum) { + inputProps.children = field.options + .map((option: string) => template(optionItem)({ option })) + .join("\n"); + } + + const name = prefix.includes("${") ? `{\`${prefix}\`}` : `"${prefix}"`; + return { ...input, - component: template(input.component)({ - options: field.options - ?.map((option) => template(optionItem)({ option })) - .join(""), + component: template(formFieldTemplate)({ + name, + label: getFieldLabel(prefix), + input: template(input.component)(inputProps), }), }; } + +function getFieldLabel(key: string): string { + const parts = key.includes(".") ? key.split(".") : [key]; + return parts.map((part) => startCase(camelCase(part))).join(" "); +} + +function getObjectDefaultValue( + field: z.ZodObject, +): Record { + // todo: make recursive ? + const defaultValues: Record = {}; + + for (const [key, value] of Object.entries(field.shape)) { + const defaultValue = inputs[value.constructor.name]?.defaultValue; + if (typeof defaultValue !== "undefined") { + defaultValues[key] = defaultValue; + } + } + + return defaultValues; +} diff --git a/src/utils/parse-zod.ts b/src/utils/parse-zod.ts new file mode 100644 index 0000000..3615ab3 --- /dev/null +++ b/src/utils/parse-zod.ts @@ -0,0 +1,87 @@ +import * as path from "node:path"; +import * as vm from "node:vm"; +import { + type Node, + Project, + type SourceFile, + SyntaxKind, + type VariableDeclaration, +} from "ts-morph"; +import { z } from "zod"; +import type { Config } from "./get-config"; + +export type ParsedSchema = { + schema: z.ZodObject; + import: string; +}; + +export function parseZodSchemasFromFile( + config: Config, + filePath: string, +): Record { + const project = new Project(); + const sourceFile = project.addSourceFileAtPath(filePath); + const schemas: Record = {}; + + for (const declaration of sourceFile.getVariableDeclarations()) { + const initializer = declaration.getInitializer(); + if (initializer && isZodObjectSchema(initializer)) { + const schemaName = declaration.getName(); + const schemaCode = initializer.getText(); + schemas[schemaName] = { + schema: evaluateZodSchema(schemaCode, filePath), + import: buildImportString(config, sourceFile, declaration, filePath), + }; + } + } + + return schemas; +} + +function buildImportString( + config: Config, + sourceFile: SourceFile, + declaration: VariableDeclaration, + filePath: string, +) { + const isDefaultExport = sourceFile + .getExportedDeclarations() + .get("default") + ?.some((d) => d === declaration); + + let importPath = path.relative(config.resolvedPaths.forms, filePath); + importPath = importPath.replace(/\.(ts|tsx|js|jsx)$/, ""); + importPath = importPath.replace(/\\/g, "/"); + + const importName = declaration.getName(); + + if (isDefaultExport) { + return `import ${importName} from "${importPath}";`; + } + + return `import { ${importName} } from "${importPath}";`; +} + +function isZodObjectSchema(node: Node) { + return ( + node.getKind() === SyntaxKind.CallExpression && + node.getFirstChild()?.getText().startsWith("z.object") + ); +} + +function evaluateZodSchema( + schemaCode: string, + filePath: string, +): z.ZodObject { + const context = { + z, + require: (id: string) => { + if (id === "zod") return z; + return require(path.resolve(path.dirname(filePath), id)); + }, + console, + }; + + const script = `const schema = ${schemaCode}; schema;`; + return vm.runInNewContext(script, context) as z.ZodObject; +} diff --git a/src/utils/templates/array-field.ts b/src/utils/templates/array-field.ts new file mode 100644 index 0000000..426f302 --- /dev/null +++ b/src/utils/templates/array-field.ts @@ -0,0 +1,35 @@ +export const arrayFieldTemplate = { + import: `import { useFieldArray } from 'react-hook-form'; + import { XIcon, PlusIcon } from "lucide-react";`, + functions: ` + const { fields, append, remove } = useFieldArray({ + control: form.control, + name: "<%= name %>", + }); + `, + component: ` +
+ {fields.map((field, index) => ( +
+ <%= children %> + +
+ )) + } + +
`, +}; diff --git a/src/utils/templates/form-field.ts b/src/utils/templates/form-field.ts index 9abca3f..3479a06 100644 --- a/src/utils/templates/form-field.ts +++ b/src/utils/templates/form-field.ts @@ -1,12 +1,11 @@ -export const formFieldTemplate = ` - render={({ field }) => ( <%= label %> - <%= component %> + <%= input %> diff --git a/src/utils/templates/form.ts b/src/utils/templates/form.ts index 3a498f0..ae30b37 100644 --- a/src/utils/templates/form.ts +++ b/src/utils/templates/form.ts @@ -1,32 +1,33 @@ export const formTemplate = ` -import { zodResolver } from "@hookform/resolvers/zod" -import { useForm } from "react-hook-form" -import { z } from "zod" +import { zodResolver } from "@hookform/resolvers/zod"; +import { useForm } from "react-hook-form"; +import type { z } from "zod"; <%= schemaImport %> <%= imports %> import { Button } from "@/registry/ui/button" import { Form, FormControl, - FormDescription, FormField, FormItem, FormLabel, FormMessage, } from "@/registry/ui/form" -const formSchema = <%= schema %> +const formSchema = <%= schema %>; export function <%= formName %>() { const form = useForm>({ resolver: zodResolver(formSchema), - defaultValues: <%= defaultValues %>, - }) + defaultValues: {}, + }); function onSubmit(values: z.infer) { // Handle form submission console.log(values) - } + }; + + <%= functions %> return (
diff --git a/src/utils/templates/inputs.ts b/src/utils/templates/inputs.ts index 26238be..c8f50bc 100644 --- a/src/utils/templates/inputs.ts +++ b/src/utils/templates/inputs.ts @@ -1,37 +1,39 @@ -export const inputs: Record = { - string: { - import: "import { Input } from '@/registry/ui/input'", - component: "", - }, - number: { - import: "import { Input } from '@/registry/ui/input'", - component: "", - }, - // todo: implement custom registry datepicker - date: { - import: "import { Input } from '@/registry/ui/input'", +import { z } from "zod"; + +export const inputs: Record< + string, + { + import: string; + component: string; + defaultValue?: string | boolean | number; + } +> = { + [z.ZodString.name]: { + import: "import { Input } from '@/registry/ui/input';", component: "", + defaultValue: "", }, - // todo: implement custom registry datetime picker - datetime: { - import: "import { Input } from '@/registry/ui/input'", + [z.ZodNumber.name]: { + import: "import { Input } from '@/registry/ui/input';", component: "", + defaultValue: 0, }, - boolean: { - import: "import { Checkbox } from '@/registry/ui/checkbox'", + [z.ZodBoolean.name]: { + import: "import { Checkbox } from '@/registry/ui/checkbox';", component: "", + defaultValue: false, }, - enum: { + [z.ZodEnum.name]: { import: - "import { Select, SelectTrigger, SelectValue, SelectContent, SelectGroup, SelectItem } from '@/registry/ui/select'", - component: ` - <%= options %> + <%= children %> `, diff --git a/src/utils/transformers/transform-import.ts b/src/utils/transformers/transform-import.ts index b0f710e..e954437 100644 --- a/src/utils/transformers/transform-import.ts +++ b/src/utils/transformers/transform-import.ts @@ -7,18 +7,13 @@ export const transformImport: Transformer = async ({ sourceFile, config }) => { for (const importDeclaration of importDeclarations) { const moduleSpecifier = importDeclaration.getModuleSpecifierValue(); - console.log(moduleSpecifier); // Replace @/registry/ui with the components alias. if (moduleSpecifier.startsWith("@/registry/ui")) { if (config.aliases.ui) { - console.log(moduleSpecifier); - console.log("ui", config.aliases.ui); - console.log("importDeclaration", importDeclaration.getText()); importDeclaration.setModuleSpecifier( moduleSpecifier.replace(/^@\/registry\/ui/, config.aliases.ui), ); - console.log(importDeclaration.getText()); } else { importDeclaration.setModuleSpecifier( moduleSpecifier.replace( diff --git a/tsconfig.json b/tsconfig.json index c8c3fc4..70a2a83 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,31 +1,24 @@ { - "$schema": "https://json.schemastore.org/tsconfig", - "compilerOptions": { - "composite": false, - "declaration": true, - "declarationMap": true, - "esModuleInterop": true, - "forceConsistentCasingInFileNames": true, - "inlineSources": false, - "moduleResolution": "node", - "noUnusedLocals": false, - "noUnusedParameters": false, - "preserveWatchOutput": true, - "skipLibCheck": true, - "strict": true, - "isolatedModules": false, - "baseUrl": ".", - "paths": { - "@/*": [ - "./*" - ] - } - }, - "include": [ - "src/**/*.ts", - "tsup.config.ts" - ], - "exclude": [ - "node_modules" - ] -} \ No newline at end of file + "$schema": "https://json.schemastore.org/tsconfig", + "compilerOptions": { + "composite": false, + "declaration": true, + "declarationMap": true, + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "inlineSources": false, + "moduleResolution": "node", + "noUnusedLocals": false, + "noUnusedParameters": false, + "preserveWatchOutput": true, + "skipLibCheck": true, + "strict": true, + "isolatedModules": false, + "baseUrl": ".", + "paths": { + "@/*": ["./*"] + } + }, + "include": ["src/**/*.ts", "tsup.config.ts"], + "exclude": ["node_modules"] +}