java(kotlin) ai框架djl

发布于:2024-06-16 ⋅ 阅读:(21) ⋅ 点赞:(0)

DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:

MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。

maven

 <!--        djl-->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.28.0</version>
        </dependency>
  <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.28.0</version>
        </dependency>
        <!--        /djl-->

Java DJL 架构图

┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘
                │
      ┌─────────▼─────────┐
      │       Engine      │
      └───────┬─┬─────────┘
              │ │
      ┌───────▼─▼─────────┐
      │     NDManager     │
      └───────┬─┬─────────┘
              │ │
    ┌─────────▼─▼───────────┐
    │    Dataset 
    └─────────┬─────────────┘
              │
    ┌─────────▼─────────────┐
    │  Trainer / Predictor  │
    └───────────────────────┘

主要组件详细描述

1. ModelZoo 和 Model
  • ModelZoo:提供多种预训练模型

    ModelZoo 的功能
    1. 模型发现与下载
      • ModelZoo 提供了一种机制,可以从多种来源(例如模型提供商、在线仓库等)发现和下载预训练模型。
      • 例如,可以从 AWS S3、Hugging Face、TensorFlow Hub 等平台下载模型。
    2. 模型加载
      • ModelZoo 提供了方便的方法来加载模型,用户可以根据需求加载不同类型的模型(例如图像分类模型、对象检测模型、自然语言处理模型等)。
      • 加载模型时,可以指定模型的名称、版本、以及模型的参数配置。
    3. 模型管理
      • ModelZoo 帮助用户管理已下载和加载的模型,可以方便地查看、更新和删除模型。
      • 通过这种方式,可以有效地管理本地的模型资源,避免重复下载和浪费存储空间。

    示例

    
    import ai.djl.Application
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.modality.Classifications
    import ai.djl.modality.cv.Image
    import ai.djl.repository.zoo.Criteria
    import ai.djl.repository.zoo.ModelZoo
    import ai.djl.translate.TranslateException
    
    
    object ModelZooExample {
        @Throws(ModelException::class, TranslateException::class)
        @JvmStatic
        fun main(args: Array<String>) {
            // 定义模型的标准
            val criteria: Criteria<Image, Classifications> = Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION) // 应用场景:图像分类
                .setTypes(Image::class.java, Classifications::class.java) // 输入输出类型
                .optFilter("backbone", "resnet50") // 模型过滤条件
                .build()
    
            // 从 ModelZoo 加载模型
            val model: Model = ModelZoo.loadModel(criteria)
    
            // 使用模型进行推理
            // ...
        }
    }
    
    
    

    ModelZoo 的类与接口

    • ModelZoo:核心类,提供模型的下载和加载功能。
    • Criteria:定义模型加载的标准和过滤条件,用于指定所需模型的应用场景、输入输出类型等。
    • ModelLoader:用于实际执行模型的下载和加载操作。
  • Model:表示一个深度学习模型的接口,包含模型的加载、保存和运行等操作。

    • ai.djl.ModelZoo

      Key Methods:
      • Model loadModel(Criteria<?, ?> criteria): Loads a model based on the provided criteria.
      • ModelInfo getModel(ModelId modelId): Retrieves information about a specific model using its ModelId.
      • Set<ModelId> listModels(ZooModel<?, ?> model): Lists all models in the zoo that match the given model.

      ai.djl.ModelInfo Interface

      ModelInfo provides metadata about a model, including its name, description, and input/output information.

      Key Methods:
      • String getName(): Returns the name of the model.
      • String getDescription(): Provides a description of the model.
      • Shape getInputShape(): Returns the shape of the input tensor.
      • Shape getOutputShape(): Returns the shape of the output tensor.

      ai.djl.ModelId Class

      ModelId uniquely identifies a model in the model zoo. It includes information about the model’s group, name, and version.

      Key Fields:
      • String getGroup(): Gets the group name of the model.
      • String getName(): Gets the name of the model.
      • String getVersion(): Gets the version of the model.

      ai.djl.Application Enum

      Application enumerates different types of applications supported by the model zoo, such as IMAGE_CLASSIFICATION, OBJECT_DETECTION, etc.

      Key Values:
      • CV.IMAGE_CLASSIFICATION
      • CV.OBJECT_DETECTION
      • NLP.TEXT_CLASSIFICATION

      ai.djl.Criteria Class

      Criteria is a builder for creating criteria objects used to filter and load models.

      Key Methods:
      • static Builder<?, ?> builder(): Creates a new builder instance.
      • Criteria<I, O> optApplication(Application application): Sets the application type.
      • Criteria<I, O> optEngine(String engine): Specifies the engine to use (e.g., MXNet, PyTorch)
      example
      
      import ai.djl.Model
      import ai.djl.ModelException
      import ai.djl.modality.Classifications
      import ai.djl.modality.cv.Image
      import ai.djl.modality.cv.ImageFactory
      import ai.djl.ndarray.NDList
      import ai.djl.translate.TranslateException
      import ai.djl.translate.Translator
      import ai.djl.translate.TranslatorContext
      import java.io.IOException
      import java.nio.file.Paths
      
      
      object DjlExample {
          @JvmStatic
          fun main(args: Array<String>) {
              // 模型路径
              val modelDir = Paths.get("models")
              val modelName = "resnet18"
              try {
                  Model.newInstance(modelName).use { model ->
                      // 加载模型
                      model.load(modelDir)
      
                      // 加载输入图像
                      val img =
                          ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"))
      
                      // 获取预测器
                      val predictor =
                          model.newPredictor(MyTranslator())
      
                      // 执行推理
                      val result = predictor.predict(img)
                      println(result)
                  }
              } catch (e: IOException) {
                  e.printStackTrace()
              } catch (e: ModelException) {
                  e.printStackTrace()
              } catch (e: TranslateException) {
                  e.printStackTrace()
              }
          }
      
          // 自定义 Translator
          private class MyTranslator : Translator<Image?, Classifications?> {
              override fun processInput(ctx: TranslatorContext?, input: Image?): NDList {
                  return NDList(input!!.toNDArray(ctx!!.ndManager))
              }
      
              override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
                  val probabilitiesNDArray = list.singletonOrThrow().softmax(1)
                  val labels: List<String> = List(100) { "name$it" }
                  return Classifications(labels, probabilitiesNDArray)
              }
          }
      }
      
      
      
2. Dataset
  • 常见的数据集类型:

    1. RandomAccessDataset:
      • RandomAccessDataset 是一种基本的数据集接口,适用于数据可以随机访问的情况,如数组或列表。
      • 它支持批处理(batching)、数据切片(slicing)等操作,适合大多数监督学习任务。
    2. IterableDataset:
      • IterableDataset 适用于数据不能随机访问的情况,如流数据或实时生成的数据。
      • 它通过迭代器(iterator)提供数据,适用于需要动态生成或处理的数据源。
    3. RecordDataset:
      • RecordDataset 是基于记录文件(record file)的数据集格式,常用于大规模数据处理。
      • 它可以高效地加载和处理数据记录,适用于分布式训练和大数据集的处理。

    DJL 的数据集组件提供的功能包括:

    1. 数据加载和预处理:
      • 支持从多种数据源加载数据,如本地文件、远程服务器、数据库等。
      • 提供数据预处理功能,如归一化、数据增强、特征提取等。
    2. 批处理(Batching):
      • 支持将数据分成小批次进行处理,适用于大规模数据集的训练。
      • 提供灵活的批处理策略,可根据需要进行自定义。
    3. 数据变换(Transformations):
      • 提供多种数据变换功能,如图像变换、文本处理、数值处理等。
      • 支持链式调用,将多个变换操作组合在一起,形成数据处理管道。
    4. 数据加载器(DataLoader):
      • DataLoader 负责将数据集打包成批次,并在训练过程中按需提供数据。
      • 支持多线程数据加载,提高数据处理效率。
  • Dataset:定义数据集的抽象类,用户可以继承该类来实现自定义的数据集。

    • import ai.djl.Model;
      import ai.djl.ModelException;
      import ai.djl.inference.Predictor;
      import ai.djl.modality.Classifications;
      import ai.djl.modality.cv.Image;
      import ai.djl.modality.cv.ImageFactory;
      import ai.djl.repository.zoo.Criteria;
      import ai.djl.repository.zoo.ModelZoo;
      import ai.djl.translate.TranslateException;
      
      import java.io.IOException;
      import java.nio.file.Paths;
      
      public class DjlExample {
          public static void main(String[] args) throws IOException, ModelException, TranslateException {
              // 加载模型
              Criteria<Image, Classifications> criteria = Criteria.builder()
                      .optEngine("TensorFlow") // 选择引擎
                      .setTypes(Image.class, Classifications.class)
                      .optModelPath(Paths.get("path/to/model"))
                      .build();
              
              try (Model model = ModelZoo.loadModel(criteria);
                   Predictor<Image, Classifications> predictor = model.newPredictor()) {
                  
                  // 加载图像
                  Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));
                  
                  // 进行推理
                  Classifications result = predictor.predict(img);
                  System.out.println(result);
              }
          }
      }
      
      
    • import ai.djl.Application;
      import ai.djl.Model;
      import ai.djl.basicdataset.cv.classification.FashionMnist;
      import ai.djl.engine.Engine;
      import ai.djl.metric.Metrics;
      import ai.djl.ndarray.NDArray;
      import ai.djl.ndarray.NDManager;
      import ai.djl.training.DefaultTrainingConfig;
      import ai.djl.training.EasyTrain;
      import ai.djl.training.Trainer;
      import ai.djl.training.dataset.Batch;
      import ai.djl.training.dataset.Dataset;
      import ai.djl.training.listener.TrainingListener;
      import ai.djl.training.loss.Loss;
      import ai.djl.training.optimizer.Optimizer;
      import ai.djl.training.tracker.Tracker;
      import ai.djl.translate.TranslateException;
      import ai.djl.util.Pair;
      
      import java.io.IOException;
      
      public class DJLDatasetExample {
      
          public static void main(String[] args) throws IOException, TranslateException {
              NDManager manager = NDManager.newBaseManager();
      
              FashionMnist fashionMnist = FashionMnist.builder()
                      .optUsage(Dataset.Usage.TRAIN)
                      .setSampling(32, true) // 32 is the batch size
                      .optLimit(Long.MAX_VALUE) // Use this to limit the number of samples
                      .build();
      
              fashionMnist.prepare();
      
              Model model = Model.newInstance("fashion-mnist-model");
              TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                      .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build())
                      .addTrainingListeners(TrainingListener.Defaults.logging());
      
              try (Trainer trainer = model.newTrainer(config)) {
                  trainer.initialize(new long[]{1, 28, 28}); // Example shape for image data
      
                  Metrics metrics = new Metrics();
                  trainer.setMetrics(metrics);
      
                  for (Batch batch : trainer.iterateDataset(fashionMnist)) {
                      EasyTrain.trainBatch(trainer, batch);
                      trainer.step();
                      batch.close();
                  }
      
                  trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));
              }
          }
      }
      
      

3. Engine 和 NDManager
  • Engine:DJL支持多个深度学习引擎,如MXNet、PyTorch、ONNX、TensorFlow,Engine接口提供统一的抽象,方便切换底层引擎。

  • NDManager:管理NDArray,用于处理多维数组,封装了底层的数组操作。

    Using DJL Engine
    
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Paths
    
    object DJLEngineExample {
        @Throws(ModelException::class, TranslateException::class, IOException::class)
        @JvmStatic
        fun main(args: Array<String>) {
            // Initialize the model
            val model = Model.newInstance("model-name", "ai.djl.pytorch") // Assuming "model-name" is valid and using PyTorch engine
    
            // Load a pre-trained model
            model.load(Paths.get("path/to/your/model")) // Ensure the path is correct
    
            // Define a translator for data preprocessing and postprocessing
            val translator: Translator<Array<Float>, Float> = object : Translator<Array<Float>, Float> {
                override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {
                    val manager = ctx.ndManager
                    val array: NDArray = manager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong())) // Reshape might be necessary
                    return NDList(array)
                }
    
                override fun processOutput(ctx: TranslatorContext, list: NDList): Float {
                    // Assuming the output is a single scalar value
                    return list[0].getFloat() // Use getFloat() to get the scalar value
                }
    
                override fun getBatchifier(): Batchifier? {
                    return null // Or implement batching if needed
                }
            }
    
            model.newPredictor(translator).use { predictor ->
                val input = arrayOf(1.0f, 2.0f, 3.0f) // Input should match the model's expected input shape
                val output = predictor.predict(input)
                println("Prediction: $output")
            }
        }
    }
    
    
    Overview of NDManager
    Key Features of NDManager:
    1. Memory Management: Automates the process of memory allocation and deallocation for NDArrays.
    2. Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
    3. Hierarchical Structure: NDManagers can create child managers, which can further manage their own NDArrays. This is useful for managing resources in complex workflows.
    Using NDManager
    
    import ai.djl.ndarray.NDManager
    
    
    object NDManagerExample {
        @JvmStatic
        fun main(args: Array<String>) {
            NDManager.newBaseManager().use { manager ->
                val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))
                println("Array: $array")
    
                // Perform operations
                val result = array.add(2.0f)
                println("Result: $result")
            }
            // No need to explicitly free the memory, it's handled by the NDManager
        }
    }
    
4. Trainer 和 Predictor
  • Trainer 类

    提供训练模型的接口,包含优化器、损失函数和训练循环等功能。用于训练深度学习模型。它封装了训练过程中的一些常见操作,如前向传播、反向传播和参数更新。

    主要功能包括:

    • 模型的训练和验证
    • 管理优化器和损失函数
    • 提供易于使用的训练循环
    代码演示

    以下是使用 DJL 的 Trainer 类训练一个简单神经网络的示例代码:

    
    import ai.djl.Model
    import ai.djl.basicdataset.cv.classification.FashionMnist
    import ai.djl.basicmodelzoo.basic.Mlp
    import ai.djl.ndarray.types.Shape
    import ai.djl.training.DefaultTrainingConfig
    import ai.djl.training.TrainingConfig
    import ai.djl.training.dataset.Dataset
    import ai.djl.training.dataset.RandomAccessDataset
    import ai.djl.training.listener.LoggingTrainingListener
    import ai.djl.training.listener.TrainingListener
    import ai.djl.training.loss.Loss
    import ai.djl.training.optimizer.Optimizer
    import ai.djl.training.tracker.FixedPerVarTracker
    import ai.djl.training.util.ProgressBar
    import ai.djl.translate.TranslateException
    import java.io.IOException
    import java.nio.file.Paths
    
    object DjlTrainerDemo {
        @Throws(IOException::class, TranslateException::class)
        @JvmStatic
        fun main(args: Array<String>) {
            // Load dataset
            val trainDataset: RandomAccessDataset =
                FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()
            trainDataset.prepare(ProgressBar())
    
            // Define model
            val model = Model.newInstance("mlp")
            model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))
    
            // Define training configuration
            val config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .optOptimizer(
                    Optimizer.sgd()
                        .setLearningRateTracker(
                            FixedPerVarTracker.builder()
                                .setDefaultValue(0.01f)
                                .build()
                        ).build()
                )
                .addTrainingListeners(LoggingTrainingListener())
    
            model.newTrainer(config).use { trainer ->
                trainer.initialize(Shape(1, (28 * 28).toLong()))
                for (epoch in 0..9) {
                    for (batch in trainer.iterateDataset(trainDataset)) {
                        trainer.step()
                        batch.close()
                    }
                    trainer.notifyListeners { listener: TrainingListener ->
                        listener.onEpoch(trainer)
                    }
                }
                model.save(Paths.get("model"), "mlp")
            }
        }
    }
    
    
    Predictor 类

    用于模型推理,接收输入数据并返回预测结果。用于对训练好的模型进行推理。它提供了一个简单的接口,用于将输入数据传递给模型并获取预测结果。

    主要功能包括:

    • 加载模型进行推理
    • 处理输入和输出数据的转换
    代码演示
    
    import ai.djl.Model
    import ai.djl.modality.Classifications
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.NDManager
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Paths
    
    object DjlPredictorDemo {
        @Throws(IOException::class, TranslateException::class)
        @JvmStatic
        fun main(args: Array<String>) {
            // Load model
            val model = Model.newInstance("mlp")
            model.load(Paths.get("model"), "mlp")
    
            // Define Translator
            val translator: Translator<NDArray, Classifications> = object : Translator<NDArray, Classifications> {
                override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {
                    return NDList(input.reshape(Shape(1, (28 * 28).toLong())))
                }
    
                override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
                    // Assuming the output NDArray is the first element in NDList
                    val probabilities = list.singletonOrThrow()
                    return Classifications(listOf("Label1", "Label2"), probabilities) // Example labels
                }
    
                override fun getBatchifier(): Batchifier {
                    return Batchifier.STACK
                }
            }
    
            model.newPredictor(translator).use { predictor ->
                val manager = NDManager.newBaseManager()
                val array = manager.ones(Shape(1, (28 * 28).toLong()))
                val classifications = predictor.predict(array)
                println(classifications)
            }
        }
    }