深度学习平台

  • 深度学习平台 > 使用文档 > SDK 使用文档

    SDK 使用文档

    最近更新时间: 2018-03-05 08:12:16

    AVA SDK 使用指南

    功能介绍

    AVA SDK 为 AVA 平台的使用者,在训练任务运行环境中提供以下功能:

    • 获取训练信息
    • 读取训练配置
    • 训练指标监控上报
    • 上传 / 下载文件

    AVA SDK 除了“训练指标监控上报”部分与具体训练框架有关,其余都是训练框架无关的。目前监控部分,AVA SDK 支持的训练框架有:

    • Caffe
    • MXNet

    安装

    通过 pip 直接在线安装

    初次安装

    pip install ava-sdk
    

    升级

    pip install --upgrade ava-sdk
    

    使用样例

    训练实例

    在 AVA 平台中,在一个训练任务的运行环境中,可以进行多次训练,每次训练被称为一个“训练实例”。在训练中,用户需要调用 SDK API 来进行训练实例的创建和状态更新。样例如下:

    from ava.train import base as train
    
    def main():
        # 初始化一个训练实例
        train_ins = train.TrainInstance()
    
        # 进行训练相关的操作
        # ...
    
        # 训练结束,更新训练实例状态,err_msg为空时表示训练正常结束,不为空表示训练异常结束
        train_ins.done(err_msg=err_msg)
    

    训练参数

    在训练过程中,会有许多参数需要配置,大部分情况为训练超参,AVA SDK 支持直接从容器环境中读取用户自定义的参数文件,文件格式为 Json。 注:设计的场景是,用户在创建/更新训练任务时,可以将参数文件由 Portal 传递至运行环境,之后再由 SDK 读取到。但是,目前 Portal 暂时不支持参数文件传入,所以在目前的实际使用中,用户需要在运行环境中对固定文件编辑,才可以从 SDK 中读取到。

    例如,参数文件如下:

    # 参数文件路径为:/workspace/params/default.json
    cat /workspace/params/default.json
    {
        "k1": "v1",
        "k2": {
            "k3": "v3"
        }
    }
    

    在脚本中读取参数

    from ava.params import params
    
    def main():
        # 获取某项value
        value = params.get_value("k1")  # v1
        value = params.get_value("k2")  # {"k3": "v3"}
        value = params.get_value("k2.k3")  # v3
        value = params.get_value("none", default="default_val")  # default_val
        # 获取所有配置
        value = params.get_all()    # 整个配置的 Json 结构
    

    训练指标监控上报

    AVA SDK 支持将训练过程中的指标上报至 AVA 平台,目前 AVA SDK 对于不同训练框架需要进行不同的实现:

    • Caffe:通过训练 log 解析(目前仅支持通过直接启动caffe进程的方式获取指标)
    • MXNet:添加 callback 函数

    样例:

    Caffe

    from ava.train import base as train
    from ava.utils import cmd
    
    def main():
        train_ins = train.TrainInstance()
        fixed_solver = "/some/path/solver.prototxt"
        train_args = ["caffe", "train", "-solver", fixed_solver]
        # AVA-SDK 调用 caffe 进程
        proc = cmd.startproc(train_args)
        logger.info("Started %s", proc)
        # AVA-SDK 添加指标监控,获取 caffe 进程的输出并解析/上报训练指标
        cmd.logproc(self.proc, [
            train_ins.get_monitor_callback("caffe"),    # 解析 && 上报 batch 粒度监控指标
            train_ins.get_epoch_end_callback("caffe",   # epoch 粒度 callback,用于上报 snapshot model
                                             batch_of_epoch=xxx,
                                             other_files=[])
        ])
        exit_code = self.proc.wait()
        logger.info("Finished proc with code %s", exit_code)
        train_ins.done()
    

    MXNet

    import mxnet as mx
    from ava.monitor import mxnet as mxnet_monitor
    from ava.train import base as train
    
    def main():
        train_ins = train.TrainInstance()
        # AVA-SDK mxnet monitor callback 初始化
        batch_end_cb = self.train_ins.get_monitor_callback(
            "mxnet",
            batch_size=actual_batch_size,
            batch_freq=batch_freq)
        epoch_end_cb = [
            # mxnet default epoch callback
            mx.callback.do_checkpoint(
                snapshot_prefix, snapshot_interval_epochs),
            self.train_ins.get_epoch_end_callback(
                "mxnet", batch_of_epoch=batch_of_epoch, epoch_interval=snapshot_interval_epochs, other_files=[])
        ]
    
        train_config = {
            # ...
            "epoch_end_callback": epoch_end_cb,
            "batch_end_callback": batch_end_cb,
            # ...
        }
    
        mod = mx.mod.Module(symbol=sym, context=ctx)
        mod.fit(train_data, eval_data=val_data, **train_config)
    
        train_ins.done()
    

    上传 / 下载文件

    略,见 AVA SDK API

    AVA SDK API

    SDK 包结构如下:

    ./ava
    ├── params          # 配置获取相关
    │   ├── params.py
    ├── train           # 训练实例相关
    │   ├── base.py         # 训练 jobID/trainingID,训练实例默认路径等
    ├── log             # 日志
    │   ├── logger.py
    ├── monitor         # 监控相关
    │   ├── mxnet.py        # batch 粒度 callback 函数,指标上报
    └── utils           # 工具类
        ├── cmd.py          # 运行子进程
        ├── fetcher.py      # 下载文件,支持七牛私有协议
        ├── fileloc.py      # 创建 / 删除目录
        ├── uploader.py     # 上传文件
        ├── utils.py        # 其他工具函数
    

    params

    params.params

    FUNCTIONS
        get_all()
            获取训练params所有数据
    
            Returns:
                params: 解析配置文件的 dict,如果解析失败,返回空 dict
    
        get_value(key, is_required=False, default=None, use_cache=True)
            获取训练params信息
    
            key 以 '.' 分割,例如 'solverOps.solverPolicy.stepvalue' 代表
            {
                ...
                'solverOps': {
                    'solverPolicy': {
                        'stepvalue': <value>
                    }
                }
            }
    
            Args:
                key: 以 '.' 分割的配置名称。
                is_required: key 不存在时异常。
                default: key 不存在时返回默认值。
    
            Returns:
                配置文件中 key 对应的内容。
    

    train

    train.base

    CLASSES
        __builtin__.object
            TrainInstance
    
        class TrainInstance(__builtin__.object)
         |  训练任务实例
         |  每一个训练实例,都会有对应的独立输出目录
         |
         |  Methods defined here:
         |
         |  __init__(self, training_ins_id=None)
         |      初始化与训练相关的路径
         |
         |      训练实例的工作目录:
         |      /workspace
         |          ├── params
         |          |   └── default.json    // 用户传递的参数json文件
         |          ├── data
         |          │   ├── test            // test data pvc
         |          │   ├── train           // train data pvc
         |          │   └── val             // val data pvc
         |          ├── model               // 存放模型的目录
         |          └── run                 // run data pvc
         |              ├── <training_ins_id1>
         |              |   ├── output      // 一次训练的 snapshot model 输出
         |              |   └── logs        // 一次训练的日志路径
         |              └── <training_ins_id2>
         |
         |      Args:
         |          training_ins_id:    用户传入的训练ID,表明:1)用户复用之前的训练ID 或者 2)用户自定义的训练ID
         |
         |  done(self, err_msg='')
         |
         |  get_monitor_callback(self, framework, **kwargs)
         |      返回各种框架的监控回调函数
         |
         |      Args:
         |          framework:  各个训练框架名称,目前支持:mxnet
         |          kwargs:     各个训练框架 callback 的初始化函数
         |
         |      Returns:
         |          callback:   各个训练框架的回调函数, or None 表示框架不支持
         |
         |  get_model_base_path(self)
         |
         |  get_snapshot_base_path(self)
         |
         |  get_testset_base_path(self)
         |
         |  get_training_ins_id(self)
         |
         |  get_trainset_base_path(self)
         |
         |  get_valset_base_path(self)
    

    log

    log.logger

    CLASSES
        logging.Handler(logging.Filterer)
            MultiProcessingLogHandler
    
        class MultiProcessingLogHandler(logging.Handler)
         |  支持多进程的 LogHandler, 最终将 log 打入 RotatingFileHandler 中
         |  
         |  Method resolution order:
         |      MultiProcessingLogHandler
         |      logging.Handler
         |      logging.Filterer
         |      __builtin__.object
         |  
         |  Methods defined here:
         |  
         |  __init__(self, filename, mode='w', maxBytes=0, backupCount=10)
         |  
         |  close(self)
         |  
         |  emit(self, record)
         |      实现 Handler 的 emit 方法,将消息丢进 queue 中
         |  
         |  receive(self)
         |  
         |  send(self, signal)
         |  
         |  setFormatter(self, fmt)
         |  
         |  ----------------------------------------------------------------------
         |  Methods inherited from logging.Handler:
         |  
         |  acquire(self)
         |      Acquire the I/O thread lock.
         |  
         |  createLock(self)
         |      Acquire a thread lock for serializing access to the underlying I/O.
         |  
         |  flush(self)
         |      Ensure all logging output has been flushed.
         |      
         |      This version does nothing and is intended to be implemented by
         |      subclasses.
         |  
         |  format(self, record)
         |      Format the specified record.
         |      
         |      If a formatter is set, use it. Otherwise, use the default formatter
         |      for the module.
         |  
         |  get_name(self)
         |  
         |  handle(self, record)
         |      Conditionally emit the specified logging record.
         |      
         |      Emission depends on filters which may have been added to the handler.
         |      Wrap the actual emission of the record with acquisition/release of
         |      the I/O thread lock. Returns whether the filter passed the record for
         |      emission.
         |  
         |  handleError(self, record)
         |      Handle errors which occur during an emit() call.
         |      
         |      This method should be called from handlers when an exception is
         |      encountered during an emit() call. If raiseExceptions is false,
         |      exceptions get silently ignored. This is what is mostly wanted
         |      for a logging system - most users will not care about errors in
         |      the logging system, they are more interested in application errors.
         |      You could, however, replace this with a custom handler if you wish.
         |      The record which was being processed is passed in to this method.
         |  
         |  release(self)
         |      Release the I/O thread lock.
         |  
         |  setLevel(self, level)
         |      Set the logging level of this handler.
         |  
         |  set_name(self, name)
         |  
         |  ----------------------------------------------------------------------
         |  Data descriptors inherited from logging.Handler:
         |  
         |  name
         |  
         |  ----------------------------------------------------------------------
         |  Methods inherited from logging.Filterer:
         |  
         |  addFilter(self, filter)
         |      Add the specified filter to this handler.
         |  
         |  filter(self, record)
         |      Determine if a record is loggable by consulting all the filters.
         |      
         |      The default is to allow the record to be logged; any filter can veto
         |      this and the record is then dropped. Returns a zero value if a record
         |      is to be dropped, else non-zero.
         |  
         |  removeFilter(self, filter)
         |      Remove the specified filter from this handler.
         |  
         |  ----------------------------------------------------------------------
         |  Data descriptors inherited from logging.Filterer:
         |  
         |  __dict__
         |      dictionary for instance variables (if defined)
         |  
         |  __weakref__
         |      list of weak references to the object (if defined)
    
    FUNCTIONS
        initialize_logger(workdir, node=None, debug=False)
            初始化 logger 。
    
        initialize_test_logger()
            初始化 unittest logger 。
    

    monitor

    monitor.mxnet

    FUNCTIONS
        full_mxnet_metrics()
            mxnet 常用 metric 列表
    

    utils

    utils.cmd

    FUNCTIONS
        runproc(args, log_hook=None)
            运行 subprocess 。
    
            1. 启动 subprocess 并处理日志。
            2. 等待运行结束。
    
            Args:
                args: Subprocess args.
                log_hook: One or multiple function(s). 对截取的每行日志处理分析。
    
            Returns:
                subprocess exit_code.
    

    utils.fetcher

    CLASSES
    
        class Fetcher(__builtin__.object)
         |  网络请求的封装,支持同步和异步两种方式,且支持使用七牛源站下载和失败重试。
         |  
         |  Methods defined here:
         |  
         |  __del__(self)
         |  
         |  __init__(self, _async=False, access_key=None, secret_key=None, req_timeout=None, download_pool=None, max_buffer_size=None)
         |      初始化 Fetcher
         |      
         |      Args:
         |          _async: 使用同步请求还是下异步请求,默认使用同步请求。
         |          access_key: 七牛账号的 access_key,默认为 None。
         |          secret_key: 七牛账号的 secret_key,默认为 None。
         |          req_timeout: 超时时间,默认为 None。
         |          download_pool: 请求最大并发数,默认为 None。
         |          max_buffer_size: 最大下载大小,默认为 None。
         |  
         |  fetch(self, url, method='GET', headers=None, data=None, from_kodo=False, retry=0, max_retry=-1, callback=None, cb_args=None, cb_kargs=None)
         |      发起网络请求
         |      Args:
         |          url: 请求地址。
         |          method: 请求方法,默认为 GET 方法。
         |          headers: 请求的 Headers,默认为 None。
         |          data: 请求的 body,默认为 None。注意 method 为 GET 时,不支持 data 参数。
         |          from_kodo: 是否从七牛源站下载,默认为 False
         |          retry: 当前重试的次数,默认为 0。max_retry 为负时无效。
         |          max_retry: 允许的最多重试次数,默认为 -1。为负时表示不重试,为 0 时表示使用默认最大重试次数。
         |          callback: 请求结束后的回调,默认为 None。只在异步 Fetcher 实例下有效。
         |          cb_args: 传入 callback 的参数。
         |          cb_kargs: 传入 callback 的参数。
         |      
         |      Returns:
         |          (res, err): 请求的响应对象(tornado.httpclient.HTTPResponse)
         |          和请求错误对象(tornado.httpclient.HTTPError)。
         |          只有同步 Fetcher 实例有返回。
         |  
         |  parse_qiniu_url(self, path)
         |      convert the qiniu://zone/bucket/key to http://domain/key
         |      :param path:path 是/bucket/key
         |      :return: normal uri
         |  
         |  start_tornado_ioloop(self)
         |  
         |  stop_tornado_ioloop(self)
         |  
         |  ----------------------------------------------------------------------
         |  Data descriptors defined here:
         |  
         |  __dict__
         |      dictionary for instance variables (if defined)
         |  
         |  __weakref__
         |      list of weak references to the object (if defined)
         |  
         |  ----------------------------------------------------------------------
         |  Data and other attributes defined here:
         |  
         |  DEFAULT_CONNECT_TIMEOUT = 20
         |  
         |  DEFAULT_MAX_BUFFER_SIZE = 1073741824
         |  
         |  DEFAULT_MAX_RETRY = 5
         |  
         |  DEFAULT_REQUEST_TIMEOUT = 600
         |  
         |  DOWNLOAD_POOL_SIZE = 75
    
        class QMacAuthFetcher(__builtin__.object)
         |  七牛带鉴权服务请求的封装
         |  
         |  Methods defined here:
         |  
         |  __init__(self, access_key, secret_key)
         |      Args:
         |          access_key: 用户的access_key
         |          secret_key: 用户的secret_key
         |  
         |  fetch(self, url, method='GET', data=None, headers=None, retry=0, max_retry=-1)
         |      Args:
         |          url: 请求地址
         |          method: 请求方法
         |          data: 请求数据
         |          headers: 请求头部
         |          retry: 当前重试次数,默认为0
         |          max_retry: 最大重试次数,默认为-1表示不重试。为0时表示使用此Fetcher的默认重试次数。
         |      
         |      Returns:
         |          (res, err) 请求的响应对象(tornado.httpclient.HTTPResponse)
         |          和请求错误对象(tornado.httpclient.HTTPError)。
         |  
         |  ----------------------------------------------------------------------
         |  Data descriptors defined here:
         |  
         |  __dict__
         |      dictionary for instance variables (if defined)
         |  
         |  __weakref__
         |      list of weak references to the object (if defined)
         |  
         |  ----------------------------------------------------------------------
         |  Data and other attributes defined here:
         |  
         |  DEFAULT_MAX_RETRY = 5
    

    utils.uploader

    CLASSES
        __builtin__.object
            Uploader
    
        class Uploader(__builtin__.object)
         |  将文件上报至Kodo
         |
         |  Methods defined here:
         |
         |  __init__(self, ak, sk, bucket=None)
         |
         |  set_bucket(self, bucket)
         |
         |  upload_file(self, filepath, key, bucket=None)
         |      上报单个文件至Kodo bucket中
         |      Args:
         |              filepath: 文件路径
         |              key: bucket key
         |              bucket: bucket名称,不提供的话,使用default_bucket
         |
         |      Returns:
         |              return None when it's ok, or raise exception
         |
         |      Raises:
         |              Exception, 当上报文件时发生错误,则raise Exception
    
    以上内容是否对您有帮助?
  • Close