From ddf38e1b5df95844c6c85d59f63539785ce7ea63 Mon Sep 17 00:00:00 2001 From: 8ga Date: Wed, 29 Oct 2025 10:17:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20LLMBenchmarkTester.java?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- LLMBenchmarkTester.java | 730 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 730 insertions(+) create mode 100644 LLMBenchmarkTester.java diff --git a/LLMBenchmarkTester.java b/LLMBenchmarkTester.java new file mode 100644 index 0000000..9de952f --- /dev/null +++ b/LLMBenchmarkTester.java @@ -0,0 +1,730 @@ +import java.io.File; +import java.io.IOException; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.lang.reflect.Field; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.time.Duration; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.*; +import java.util.concurrent.*; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + *

+ * JDK版本必须大于或等于21, 直接运行将生成一份bat脚本或shell脚本, 下载JDK可以在浏览器打开链接按需下载: + * https://www.azul.com/downloads/?version=java-21-lts&package=jdk#zulu + *

+ */ +public class LLMBenchmarkTester { + + public static final String SEP = "============================================================================================="; + + public static final Field[] PARAM_FIELD = ScriptParameter.class.getDeclaredFields(); + + public static final Pattern CONTENT_PATTERN = Pattern.compile("\"content\"\\s*:\\s*\"([^\"]*)\""); + + public static void main(String[] args) throws Exception { + if (args == null || args.length == 0) { + createRunScript(); + } else { + ScriptParameter param = readScriptParameter(args); + printScriptParam(param); + List executeContexts = new ArrayList<>(); + if (param.isTestChatModel()) { + TextQuestion textQuestion = new TextQuestion(param); + List tasks = submit( + param.modelName.split(","), + param.threadSize.split(","), + param, + textQuestion, + null + ); + if (!tasks.isEmpty()) { + executeContexts.addAll(tasks); + } + } + if (param.isTestVlModel()) { + ImageQuestion imageQuestion = new ImageQuestion(param); + List tasks = submit( + param.vlModelName.split(","), + param.threadSize.split(","), + param, + null, + imageQuestion + ); + if (!tasks.isEmpty()) { + executeContexts.addAll(tasks); + } + } + if (!executeContexts.isEmpty()) { + String today = LocalDate.now().format(DateTimeFormatter.ofPattern("yyyyMMdd")); + String logName = String.format("llm_bench_%s.log", today); + Path logPath = Path.of(System.getProperty("user.dir"), logName); + for (ExecuteContext executeContext : executeContexts) { + executeContext.latch.await(); + writeLog(logPath, executeContext); + executeContext.executor.shutdownNow(); + executeContext.sessionMap.clear(); + } + } + } + } + + private static void writeLog(Path logPath, ExecuteContext executeContext) throws IOException { + Collection values = executeContext.sessionMap.values(); + double avgResponse = values.stream().mapToLong(HttpContext::toEndMillis).filter(d -> d > 0L).average().orElse(0D); + long totalTime = values.stream().mapToLong(HttpContext::toFinishMillis).sum(); + long successNum = values.stream().filter(d -> d.success).count(); + long maxResponse = values.stream().mapToLong(HttpContext::toEndMillis).max().orElse(0L); + int outTextLength = values.stream().mapToInt(d -> d.outTexts.stream().mapToInt(s -> s != null ? s.length() : 0).sum()).sum(); + int outTextCount = values.stream().mapToInt(d -> d.outTexts != null ? d.outTexts.size() : 0).sum(); + String format = """ + ------------------------------------------ + 模型: %s + 并发量: %d + 问题数量: %d + 成功: %d + 首次响应最长耗时: %d毫秒 + 首次响应平均耗时: %f毫秒 + 一共输出: %d字, 共输出%d次, 共计耗时:%d毫秒 + ------------------------------------------ + %s + """; + String msg = String.format( + format, + executeContext.model, + executeContext.threadSize, + executeContext.sessionMap.size(), + successNum, + maxResponse, + avgResponse, + outTextLength, + outTextCount, + totalTime, + System.lineSeparator() + ); + Files.writeString(logPath, msg, StandardCharsets.UTF_8, StandardOpenOption.CREATE, StandardOpenOption.APPEND); + } + + private static List submit(String[] models, String[] threadSizeStr, ScriptParameter param, TextQuestion textQuestion, ImageQuestion imageQuestion) { + List threadSizeList = Arrays.stream(threadSizeStr).map(s -> Integer.parseInt(s.strip())).toList(); + List executeContexts = new ArrayList<>(); + for (Integer threadSize : threadSizeList) { + for (String model : models) { + if (textQuestion != null) { + executeContexts.add(execute(threadSize, model, param, textQuestion.getRequestParams(model))); + } else if (imageQuestion != null) { + executeContexts.add(execute(threadSize, model, param, imageQuestion.getRequestParams(model))); + } + } + } + return executeContexts; + } + + private static ExecuteContext execute(int threadSize, String model, ScriptParameter param, List requestParams) { + CountDownLatch latch = new CountDownLatch(requestParams.size()); + ConcurrentHashMap sessionMap = new ConcurrentHashMap<>(requestParams.size()); + ExecutorService executorService = Executors.newFixedThreadPool(threadSize); + URI uri = URI.create(param.openAiApiHost); + executorService.execute(() -> { + for (String requestBody : requestParams) { + startHttp(uri, param.apiKey, requestBody, latch, sessionMap); + } + }); + return new ExecuteContext(model, threadSize, sessionMap, latch, executorService); + } + + record ExecuteContext(String model, + int threadSize, + ConcurrentHashMap sessionMap, + CountDownLatch latch, + ExecutorService executor) { + + } + + private static void startHttp(URI uri, String apiKey, String requestBody, CountDownLatch latch, Map sessionMap) { + HttpRequest httpRequest = HttpRequest.newBuilder() + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", "Bearer " + apiKey) + .timeout(Duration.ofSeconds(15)) + .POST(HttpRequest.BodyPublishers.ofString(requestBody)) + .build(); + try (HttpClient client = HttpClient.newHttpClient()) { + HttpContext context = new HttpContext(); + context.start = LocalDateTime.now(); + context.outTexts = new ArrayList<>(); + Flow.Subscriber subscriber = createResponseFluxHandler(context); + CompletableFuture> future = + client.sendAsync(httpRequest, HttpResponse.BodyHandlers.fromLineSubscriber(subscriber)); + handleHttpResponseFuture(future, latch, sessionMap, context); + } + } + + private static class HttpContext { + long sessionId = System.nanoTime(); + LocalDateTime start; + LocalDateTime end; + LocalDateTime completed; + boolean success; + List outTexts; + + public long toEndMillis() { + return this.end != null ? Duration.between(this.start, this.end).toMillis() : 0L; + } + + public long toFinishMillis() { + return this.completed != null ? Duration.between(this.start, this.completed).toMillis() : 0L; + } + } + + private static void handleHttpResponseFuture(CompletableFuture> future, + CountDownLatch latch, + Map sessionMap, + HttpContext context) { + future.whenComplete((response, exception) -> context.success = false) + .thenAccept(response -> { + context.success = response.statusCode() == 200; + sessionMap.putIfAbsent(context.sessionId, context); + latch.countDown(); + }).exceptionally(err -> { + context.success = false; + sessionMap.putIfAbsent(context.sessionId, context); + latch.countDown(); + return null; + }); + } + + private static Flow.Subscriber createResponseFluxHandler(HttpContext context) { + return new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + context.end = LocalDateTime.now(); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(String item) { + if (item != null && !item.isEmpty()) { + Matcher matcher = CONTENT_PATTERN.matcher(item); + String group; + if (matcher.find() && null != (group = matcher.group(1)) && !group.isEmpty()) { + context.outTexts.add(group); + } + } + } + + @Override + public void onError(Throwable throwable) { + context.success = false; + } + + @Override + public void onComplete() { + context.completed = LocalDateTime.now(); +// System.out.println(context.outTexts); + } + }; + } + + private static class ImageQuestion { + private static final Map> cache = new ConcurrentHashMap<>(); + private static final String template = """ + { + "model": "${model}", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "这张图片里有什么?" + }, + { + "type": "image_url", + "image_url": { + "url": "${imageBase64}" + } + } + ] + } + ], + "stream": true + } + """.strip(); + + List list; + + public String getImgHead(File file) { + if (file.getName().endsWith("png")) { + return "image/png"; + } + if (file.getName().endsWith("jpg") || file.getName().endsWith("jpeg")) { + return "image/jpeg"; + } + return null; + } + + public String tryEncodeBase64(File file, Path path) { + String imgHead = getImgHead(file); + if (imgHead == null || imgHead.isBlank()) { + return null; + } + try { + return "data:" + imgHead + ";base64," + Base64.getEncoder().encodeToString(Files.readAllBytes(path)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public List base64Image(ScriptParameter parameter) throws IOException { + try (var files = Files.list(Path.of(parameter.vlImgFolder))) { + return files.map(path -> { + File file = path.toFile(); + return tryEncodeBase64(file, path); + }).filter(Objects::nonNull).toList(); + } + } + + public ImageQuestion(ScriptParameter parameter) throws IOException { + List base64List = base64Image(parameter); + int imgNum = Integer.parseInt(parameter.imgSize); + this.list = new ArrayList<>(imgNum); + do { + for (String item : base64List) { + this.list.add(item); + if (this.list.size() == imgNum) { + break; + } + } + } while (this.list.size() < imgNum); + } + + public List getRequestParams(String model) { + List cacheParams = cache.get(model); + if (cacheParams != null && !cacheParams.isEmpty()) { + return cacheParams; + } + List params = this.list.stream().map(s -> this.toJsonParam(model, s)).toList(); + cache.put(model, params); + return params; + } + + public String toJsonParam(String model, String imageBase64) { + return template.replace("${model}", model).replace("${imageBase64}", imageBase64); + } + } + + private static class TextQuestion { + private static final Map> cache = new ConcurrentHashMap<>(); + private static final String template = """ + { + "model": "${model}", + "messages": [ + { + "role": "user", + "content": "${prompt}" + } + ], + "stream": true + } + """.strip(); + // 解析文件得到的问题列表 + List list; + + public TextQuestion(ScriptParameter parameter) throws IOException { + this.list = Files.readAllLines(Path.of(parameter.chatDatasetsPath)); + } + + public List getRequestParams(String model) { + List cacheParams = cache.get(model); + if (cacheParams != null && cacheParams.isEmpty()) { + return cacheParams; + } + List params = this.list.stream().map(s -> this.toJsonParam(model, s)).toList(); + cache.put(model, params); + return params; + } + + public String toJsonParam(String model, String prompt) { + return template.replace("${model}", model).replace("${prompt}", prompt); + } + } + + private static void printScriptParam(ScriptParameter param) throws IllegalAccessException { + System.out.println("本次执行脚本的参数如下:"); + for (Field field : PARAM_FIELD) { + if (field.isAnnotationPresent(EnvName.class)) { + System.out.println(SEP); + EnvName annotation = field.getAnnotation(EnvName.class); + field.setAccessible(true); + Object value = field.get(param); + System.out.printf("参数: %s 数值: %s%n", annotation.value(), value); + } + } + System.out.println(SEP); + } + + private static void createRunScript() throws IOException { + String osName = System.getProperty("os.name").toLowerCase(); + if (osName.contains("win")) { + generateWindowsBat(); + } else { + generateShellScript(); + } + } + + private static File createScripeFile(String extName) { + String date = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd")); + File file = new File(String.format("llm_benchmark_tester_%s.%s", date, extName)); + if (file.exists()) { + throw new RuntimeException(String.format("您可以通过 %s 脚本直接运行", file.getAbsolutePath())); + } + boolean createBat; + try { + createBat = file.createNewFile(); + } catch (Exception e) { + throw new RuntimeException(String.format("创建 %s 脚本文件异常: %s", file.getAbsolutePath(), e.getMessage()), e); + } + if (!createBat) { + throw new RuntimeException(String.format("创建 %s 脚本文件失败", file.getAbsolutePath())); + } + System.out.println("已为你生成一份脚本, 请修改脚本中的环境变量, 使用脚本运行"); + System.out.printf("脚本的存储路径 %s%n", file.getAbsolutePath()); + System.out.println("运行脚本之前, 请确保脚本文件的换行符与系统相匹配, 否则会无法运行"); + return file; + } + + private static void writeScriptFile(File file, String template, BiConsumer> eachFunc) throws IOException { + List envLines = new ArrayList<>(); + for (Field field : PARAM_FIELD) { + if (field.isAnnotationPresent(EnvName.class)) { + EnvName annotation = field.getAnnotation(EnvName.class); + eachFunc.accept(annotation, envLines); + } + } + if (!envLines.isEmpty()) { + String envList = envLines.stream().collect(Collectors.joining(System.lineSeparator())); + String script = template.replace("${ENV_LINES}", envList); + Files.writeString(file.toPath(), script.strip(), StandardCharsets.UTF_8); + } + } + + private static void generateWindowsBat() throws IOException { + File file = createScripeFile("bat"); + String batTemplate = """ + @echo off + + :: java可执行文件路径, 不是JAVA_HOME, 是完整的java可执行文件路径, 例如: D:\\jdk-2108\\bin\\java + set JAVA_BIN= + + :: 脚本存放路径, 例如: E:\\JExample\\src\\LLMBenchmarkTester.java + set SCRIPT_PATH= + + ${ENV_LINES} + + :: 基于环境变量的方式执行, 交互式命令行执行用这个命令: %JAVA_BIN% %SCRIPT_PATH% -p input + %JAVA_BIN% %SCRIPT_PATH% -p env + + pause + """; + writeScriptFile(file, batTemplate, (envName, envLines) -> { + envLines.add(":: " + envName.desc()); + envLines.add("set " + envName.value() + "="); + }); + } + + private static void generateShellScript() throws IOException { + File file = createScripeFile("sh"); + String bashTemplate = """ + #!/bin/bash + # java可执行文件路径, 不是JAVA_HOME, 是完整的java可执行文件路径, 例如: /opt/jdk-2108/bin/java + JAVA_BIN="" + + # 脚本存放路径, 例如: /home/user/JExample/src/LLMBenchmarkTester.java + SCRIPT_PATH="" + + ${ENV_LINES} + + # 基于环境变量的方式执行, 交互式命令行执行用这个命令: $JAVA_BIN $SCRIPT_PATH -p input + "$JAVA_BIN" "$SCRIPT_PATH" -p env + """; + writeScriptFile(file, bashTemplate, (envName, envLines) -> { + envLines.add("# " + envName.desc()); + envLines.add(envName.value() + "="); + }); + } + + private static ScriptParameter readScriptParameter(String[] args) throws IllegalAccessException { + if (args != null && args.length > 0) { + boolean p = Arrays.stream(args).anyMatch(s -> s.equalsIgnoreCase("-p")); + if (p && Arrays.stream(args).anyMatch(s -> s.equalsIgnoreCase("env"))) { + return initScriptParamFromEnv(); + } + if (p && Arrays.stream(args).anyMatch(s -> s.equalsIgnoreCase("input"))) { + return initScriptParamFromAsk(); + } + } + throw new RuntimeException("命令错误, 请检查参数是否正确"); + } + + private static ScriptParameter initScriptParamFromEnv() throws IllegalAccessException { + ScriptParameter param = new ScriptParameter(); + param.channel = 1; + for (Field field : PARAM_FIELD) { + if (field.isAnnotationPresent(EnvName.class)) { + EnvName envName = field.getAnnotation(EnvName.class); + String fieldValue = System.getenv(envName.value()); + String formatValue = formatValue(fieldValue, field); + if (field.isAnnotationPresent(NotBlank.class) && (formatValue == null || formatValue.isBlank())) { + throw new RuntimeException(String.format("环境变量[%s]不能为空或空白字符", envName.value())); + } + if (!isValidValue(formatValue, field)) { + throw new RuntimeException(String.format("环境变量[%s]数值不合法, 当前值:[%s]", envName.value(), formatValue)); + } + field.setAccessible(true); + field.set(param, fieldValue); + } + } + return param; + } + + private static ScriptParameter initScriptParamFromAsk() throws IllegalAccessException { + ScriptParameter param = new ScriptParameter(); + param.channel = 2; + Scanner scanner = new Scanner(System.in); + for (Field field : PARAM_FIELD) { + if (field.isAnnotationPresent(AskUser.class)) { + AskUser askUser = field.getAnnotation(AskUser.class); + System.out.println(askUser.value() + ":"); + boolean isNotBlank = field.isAnnotationPresent(NotBlank.class); + for (; ; ) { + String userInput = scanner.nextLine().trim(); + String formatValue = formatValue(userInput, field); + // 允许为空并且输入值为空 + if (!isNotBlank && (formatValue == null || formatValue.isBlank())) { + break; + } + // 非空并且输入值合法 + if (isNotBlank && formatValue != null && !formatValue.isBlank() && isValidValue(formatValue, field)) { + field.setAccessible(true); + field.set(param, formatValue); + break; + } + System.out.print("请重新输入:"); + } + } + } + return param; + } + + // 顺序校验 + private static Boolean isValidValue(String formatValue, Field field) { + if (field.isAnnotationPresent(Validator.class)) { + Validator anno = field.getAnnotation(Validator.class); + return Arrays.stream(anno.value()) + .map(validator -> Constants.TEXT_VALIDATOR.get(validator.name()).apply(formatValue)) + .allMatch(Boolean.TRUE::equals); + } + return Boolean.TRUE; + } + + // 顺序格式化 + private static String formatValue(String fieldValue, Field field) { + if (field.isAnnotationPresent(Formatter.class)) { + Formatter anno = field.getAnnotation(Formatter.class); + for (TextFormater fmt : anno.value()) { + fieldValue = Constants.TEXT_FORMATER.get(fmt.name()).apply(fieldValue); + } + } + return fieldValue; + } + + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.FIELD}) + public @interface AskUser { + String value(); + } + + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.FIELD}) + public @interface EnvName { + String value() default ""; + + String desc(); + } + + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.FIELD}) + public @interface NotBlank { + + } + + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.FIELD}) + public @interface Validator { + TextValidator[] value(); + } + + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.FIELD}) + public @interface Formatter { + TextFormater[] value(); + } + + public enum TextValidator { + MUST_URL, MUST_FOLDER, MUST_TXT, MUST_NUM; + } + + public enum TextFormater { + STRIP, COMMA_CN_2_EN; + } + + private static class Constants { + // 去除字符串两端空白字符和制表符 + public static final Function STRIP_FORMATTER = + str -> Optional.ofNullable(str).map(java.lang.String::strip).orElse(""); + + // 中文逗号替换成英文逗号 + public static final Function COMMA_CN_2_EN = + str -> Optional.ofNullable(str).map(d -> d.replaceAll(",", ",")).orElse(""); + + // 字符串必须是一个http链接 + public static final Function URL_VALIDATOR = + str -> str != null && (str.startsWith("http://") || str.startsWith("https://")); + + // 字符串必须是一个合法的文件路径且已存在的文件夹 + public static final Function FOLDER_VALIDATOR = str -> { + if (str != null && !str.isBlank()) { + try { + File file = new File(str); + return file.exists() && file.isDirectory(); + } catch (Exception e) { + return false; + } + } + return true; + }; + + // 字符串必须是一个合法的文件路径且已存在的txt文件 + public static final Function TXT_FILE_VALIDATOR = str -> { + if (str != null && !str.isBlank()) { + try { + File file = new File(str); + return file.exists() && file.isFile() && file.getName().endsWith(".txt"); + } catch (Exception e) { + return false; + } + } + return true; + }; + + // 字符串必须是一个整数 + public static final Function NUMBER_VALIDATOR = str -> { + if (str != null && !str.isBlank()) { + try { + Integer.parseInt(str); + } catch (Exception e) { + return false; + } + } + return true; + }; + + // 文本格式化工具注册表 + public static final Map> TEXT_FORMATER = + Map.of( + TextFormater.STRIP.name(), Constants.STRIP_FORMATTER, + TextFormater.COMMA_CN_2_EN.name(), Constants.COMMA_CN_2_EN + ); + + // 文本验证工具注册表 + public static final Map> TEXT_VALIDATOR = + Map.of( + TextValidator.MUST_URL.name(), URL_VALIDATOR, + TextValidator.MUST_FOLDER.name(), FOLDER_VALIDATOR, + TextValidator.MUST_TXT.name(), TXT_FILE_VALIDATOR, + TextValidator.MUST_NUM.name(), NUMBER_VALIDATOR + ); + } + + public static class ScriptParameter { + + // 1=环境变量, 2=交互式命令行 + int channel; + + @NotBlank + @EnvName(value = "BENCH_LLM_API_HOST", desc = "OpenAI API 的访问地址, 例如: http://localhost:8080/v1/chat/completions") + @Validator(value = TextValidator.MUST_URL) + @Formatter(value = TextFormater.STRIP) + @AskUser(value = "请输入 OpenAI API 的访问地址 (例如: http://localhost:8080/v1/chat/completions)") + String openAiApiHost; + + @NotBlank + @EnvName(value = "BENCH_LLM_API_KEY", desc = "ApiKey或者叫API令牌") + @Formatter(value = TextFormater.STRIP) + @AskUser(value = "请输入ApiKey或者叫API令牌") + String apiKey; + + @NotBlank + @EnvName(value = "BENCH_THREAD_SIZE_ARRAY", desc = "请输入线程池配置, 示例值: 10,50,100") + @Formatter(value = {TextFormater.STRIP, TextFormater.COMMA_CN_2_EN}) + @AskUser(value = "请输入线程池配置 (示例值: 10,50,100)") + String threadSize; + + @EnvName(value = "BENCH_LLM_MODEL_NAME", desc = "文本模型名称, 多个使用英文逗号隔开, 如果不测试文生文模型可以不设置, 示例值: qwen2.5,qwen3") + @Formatter(value = {TextFormater.STRIP, TextFormater.COMMA_CN_2_EN}) + @AskUser(value = "请输入文本模型名称, 多个使用英文逗号隔开, 如果不测试文生文模型可以直接回车 (示例值: qwen2.5,qwen3)") + String modelName; + + @EnvName(value = "BENCH_LLM_VL_MODEL_NAME", desc = "VL模型名称, 多个用英文逗号隔开, 如果不测试VL模型可以不设置") + @Formatter(value = {TextFormater.STRIP, TextFormater.COMMA_CN_2_EN}) + @AskUser(value = "请输入VL模型名称, 多个用英文逗号隔开, 如果不测试VL模型可以直接回车") + String vlModelName; + + @EnvName(value = "BENCH_LLM_VL_IMG_FOLDER", desc = "调用VL模型的图片存储目录, 如果不测试VL模型可以不设置, 示例值: /home/image") + @Validator(value = TextValidator.MUST_FOLDER) + @Formatter(value = {TextFormater.STRIP}) + @AskUser(value = "请输入调用VL模型的图片存储目录, 如果不测试VL模型可以直接回车 (示例值: /home/image)") + String vlImgFolder; + + @EnvName(value = "BENCH_LLM_CHAT_MODEL_DATASETS", desc = "文生文测试数据集的文件路径, 如果不测试文生文模型可以不设置, 必须是一个.txt文件 (示例值: /home/datasets.txt)") + @Validator(value = TextValidator.MUST_TXT) + @AskUser(value = "请输入文生文测试数据集的文件路径, 必须是一个.txt文件, 如果不测试文生文模型可以直接回车 (示例值: /home/datasets.txt)") + String chatDatasetsPath; + + @EnvName(value = "BENCH_LLM_VL_IMG_SIZE", desc = "调用VL模型的测试图片数量, 如果文件夹下的图片数量不够, 会复制直到到足够数量, 如果不测试VL模型可以不设置 (示例值: 300)") + @Validator(value = TextValidator.MUST_NUM) + @Formatter(value = TextFormater.STRIP) + @AskUser("请输入调用VL模型的测试图片数量, 如果文件夹下的图片数量不够, 会复制直到到足够数量, 如果不测试VL模型可以直接回车 (示例值: 300)") + String imgSize; + + public boolean isTestChatModel() { + return this.modelName != null && !this.modelName.isBlank() + && this.chatDatasetsPath != null && !this.chatDatasetsPath.isBlank(); + } + + public boolean isTestVlModel() { + return this.vlModelName != null && !this.vlModelName.isBlank() + && this.vlImgFolder != null && !this.vlImgFolder.isBlank() + && this.imgSize != null && !this.imgSize.isBlank(); + } + } + +}