SDK 使用文档
AVA SDK 使用指南
功能介绍
AVA SDK 为 AVA 平台的使用者,在训练任务运行环境中提供以下功能:
- 获取训练信息
- 读取训练配置
- 训练指标监控上报
- 上传 / 下载文件
AVA SDK 除了“训练指标监控上报”部分与具体训练框架有关,其余都是训练框架无关的。目前监控部分,AVA SDK 支持的训练框架有:
- Caffe
- MXNet
安装
通过 pip 直接在线安装
初次安装
pip install ava-sdk
升级
pip install --upgrade ava-sdk
使用样例
训练实例
在 AVA 平台中,在一个训练任务的运行环境中,可以进行多次训练,每次训练被称为一个“训练实例”。在训练中,用户需要调用 SDK API 来进行训练实例的创建和状态更新。样例如下:
from ava.train import base as train
def main():
# 初始化一个训练实例
train_ins = train.TrainInstance()
# 进行训练相关的操作
# ...
# 训练结束,更新训练实例状态,err_msg为空时表示训练正常结束,不为空表示训练异常结束
train_ins.done(err_msg=err_msg)
训练参数
在训练过程中,会有许多参数需要配置,大部分情况为训练超参,AVA SDK 支持直接从容器环境中读取用户自定义的参数文件,文件格式为 Json。 注:设计的场景是,用户在创建/更新训练任务时,可以将参数文件由 Portal 传递至运行环境,之后再由 SDK 读取到。但是,目前 Portal 暂时不支持参数文件传入,所以在目前的实际使用中,用户需要在运行环境中对固定文件编辑,才可以从 SDK 中读取到。
例如,参数文件如下:
# 参数文件路径为:/workspace/params/default.json
cat /workspace/params/default.json
{
"k1": "v1",
"k2": {
"k3": "v3"
}
}
在脚本中读取参数
from ava.params import params
def main():
# 获取某项value
value = params.get_value("k1") # v1
value = params.get_value("k2") # {"k3": "v3"}
value = params.get_value("k2.k3") # v3
value = params.get_value("none", default="default_val") # default_val
# 获取所有配置
value = params.get_all() # 整个配置的 Json 结构
训练指标监控上报
AVA SDK 支持将训练过程中的指标上报至 AVA 平台,目前 AVA SDK 对于不同训练框架需要进行不同的实现:
- Caffe:通过训练 log 解析(目前仅支持通过直接启动caffe进程的方式获取指标)
- MXNet:添加 callback 函数
样例:
Caffe
from ava.train import base as train
from ava.utils import cmd
def main():
train_ins = train.TrainInstance()
fixed_solver = "/some/path/solver.prototxt"
train_args = ["caffe", "train", "-solver", fixed_solver]
# AVA-SDK 调用 caffe 进程
proc = cmd.startproc(train_args)
logger.info("Started %s", proc)
# AVA-SDK 添加指标监控,获取 caffe 进程的输出并解析/上报训练指标
cmd.logproc(self.proc, [
train_ins.get_monitor_callback("caffe"), # 解析 && 上报 batch 粒度监控指标
train_ins.get_epoch_end_callback("caffe", # epoch 粒度 callback,用于上报 snapshot model
batch_of_epoch=xxx,
other_files=[])
])
exit_code = self.proc.wait()
logger.info("Finished proc with code %s", exit_code)
train_ins.done()
MXNet
import mxnet as mx
from ava.monitor import mxnet as mxnet_monitor
from ava.train import base as train
def main():
train_ins = train.TrainInstance()
# AVA-SDK mxnet monitor callback 初始化
batch_end_cb = self.train_ins.get_monitor_callback(
"mxnet",
batch_size=actual_batch_size,
batch_freq=batch_freq)
epoch_end_cb = [
# mxnet default epoch callback
mx.callback.do_checkpoint(
snapshot_prefix, snapshot_interval_epochs),
self.train_ins.get_epoch_end_callback(
"mxnet", batch_of_epoch=batch_of_epoch, epoch_interval=snapshot_interval_epochs, other_files=[])
]
train_config = {
# ...
"epoch_end_callback": epoch_end_cb,
"batch_end_callback": batch_end_cb,
# ...
}
mod = mx.mod.Module(symbol=sym, context=ctx)
mod.fit(train_data, eval_data=val_data, **train_config)
train_ins.done()
上传 / 下载文件
略,见 AVA SDK API
AVA SDK API
SDK 包结构如下:
./ava
├── params # 配置获取相关
│ ├── params.py
├── train # 训练实例相关
│ ├── base.py # 训练 jobID/trainingID,训练实例默认路径等
├── log # 日志
│ ├── logger.py
├── monitor # 监控相关
│ ├── mxnet.py # batch 粒度 callback 函数,指标上报
└── utils # 工具类
├── cmd.py # 运行子进程
├── fetcher.py # 下载文件,支持七牛私有协议
├── fileloc.py # 创建 / 删除目录
├── uploader.py # 上传文件
├── utils.py # 其他工具函数
params
params.params
FUNCTIONS
get_all()
获取训练params所有数据
Returns:
params: 解析配置文件的 dict,如果解析失败,返回空 dict
get_value(key, is_required=False, default=None, use_cache=True)
获取训练params信息
key 以 '.' 分割,例如 'solverOps.solverPolicy.stepvalue' 代表
{
...
'solverOps': {
'solverPolicy': {
'stepvalue': <value>
}
}
}
Args:
key: 以 '.' 分割的配置名称。
is_required: key 不存在时异常。
default: key 不存在时返回默认值。
Returns:
配置文件中 key 对应的内容。
train
train.base
CLASSES
__builtin__.object
TrainInstance
class TrainInstance(__builtin__.object)
| 训练任务实例
| 每一个训练实例,都会有对应的独立输出目录
|
| Methods defined here:
|
| __init__(self, training_ins_id=None)
| 初始化与训练相关的路径
|
| 训练实例的工作目录:
| /workspace
| ├── params
| | └── default.json // 用户传递的参数json文件
| ├── data
| │ ├── test // test data pvc
| │ ├── train // train data pvc
| │ └── val // val data pvc
| ├── model // 存放模型的目录
| └── run // run data pvc
| ├── <training_ins_id1>
| | ├── output // 一次训练的 snapshot model 输出
| | └── logs // 一次训练的日志路径
| └── <training_ins_id2>
|
| Args:
| training_ins_id: 用户传入的训练ID,表明:1)用户复用之前的训练ID 或者 2)用户自定义的训练ID
|
| done(self, err_msg='')
|
| get_monitor_callback(self, framework, **kwargs)
| 返回各种框架的监控回调函数
|
| Args:
| framework: 各个训练框架名称,目前支持:mxnet
| kwargs: 各个训练框架 callback 的初始化函数
|
| Returns:
| callback: 各个训练框架的回调函数, or None 表示框架不支持
|
| get_model_base_path(self)
|
| get_snapshot_base_path(self)
|
| get_testset_base_path(self)
|
| get_training_ins_id(self)
|
| get_trainset_base_path(self)
|
| get_valset_base_path(self)
log
log.logger
CLASSES
logging.Handler(logging.Filterer)
MultiProcessingLogHandler
class MultiProcessingLogHandler(logging.Handler)
| 支持多进程的 LogHandler, 最终将 log 打入 RotatingFileHandler 中
|
| Method resolution order:
| MultiProcessingLogHandler
| logging.Handler
| logging.Filterer
| __builtin__.object
|
| Methods defined here:
|
| __init__(self, filename, mode='w', maxBytes=0, backupCount=10)
|
| close(self)
|
| emit(self, record)
| 实现 Handler 的 emit 方法,将消息丢进 queue 中
|
| receive(self)
|
| send(self, signal)
|
| setFormatter(self, fmt)
|
| ----------------------------------------------------------------------
| Methods inherited from logging.Handler:
|
| acquire(self)
| Acquire the I/O thread lock.
|
| createLock(self)
| Acquire a thread lock for serializing access to the underlying I/O.
|
| flush(self)
| Ensure all logging output has been flushed.
|
| This version does nothing and is intended to be implemented by
| subclasses.
|
| format(self, record)
| Format the specified record.
|
| If a formatter is set, use it. Otherwise, use the default formatter
| for the module.
|
| get_name(self)
|
| handle(self, record)
| Conditionally emit the specified logging record.
|
| Emission depends on filters which may have been added to the handler.
| Wrap the actual emission of the record with acquisition/release of
| the I/O thread lock. Returns whether the filter passed the record for
| emission.
|
| handleError(self, record)
| Handle errors which occur during an emit() call.
|
| This method should be called from handlers when an exception is
| encountered during an emit() call. If raiseExceptions is false,
| exceptions get silently ignored. This is what is mostly wanted
| for a logging system - most users will not care about errors in
| the logging system, they are more interested in application errors.
| You could, however, replace this with a custom handler if you wish.
| The record which was being processed is passed in to this method.
|
| release(self)
| Release the I/O thread lock.
|
| setLevel(self, level)
| Set the logging level of this handler.
|
| set_name(self, name)
|
| ----------------------------------------------------------------------
| Data descriptors inherited from logging.Handler:
|
| name
|
| ----------------------------------------------------------------------
| Methods inherited from logging.Filterer:
|
| addFilter(self, filter)
| Add the specified filter to this handler.
|
| filter(self, record)
| Determine if a record is loggable by consulting all the filters.
|
| The default is to allow the record to be logged; any filter can veto
| this and the record is then dropped. Returns a zero value if a record
| is to be dropped, else non-zero.
|
| removeFilter(self, filter)
| Remove the specified filter from this handler.
|
| ----------------------------------------------------------------------
| Data descriptors inherited from logging.Filterer:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
FUNCTIONS
initialize_logger(workdir, node=None, debug=False)
初始化 logger 。
initialize_test_logger()
初始化 unittest logger 。
monitor
monitor.mxnet
FUNCTIONS
full_mxnet_metrics()
mxnet 常用 metric 列表
utils
utils.cmd
FUNCTIONS
runproc(args, log_hook=None)
运行 subprocess 。
1. 启动 subprocess 并处理日志。
2. 等待运行结束。
Args:
args: Subprocess args.
log_hook: One or multiple function(s). 对截取的每行日志处理分析。
Returns:
subprocess exit_code.
utils.fetcher
CLASSES
class Fetcher(__builtin__.object)
| 网络请求的封装,支持同步和异步两种方式,且支持使用七牛源站下载和失败重试。
|
| Methods defined here:
|
| __del__(self)
|
| __init__(self, _async=False, access_key=None, secret_key=None, req_timeout=None, download_pool=None, max_buffer_size=None)
| 初始化 Fetcher
|
| Args:
| _async: 使用同步请求还是下异步请求,默认使用同步请求。
| access_key: 七牛账号的 access_key,默认为 None。
| secret_key: 七牛账号的 secret_key,默认为 None。
| req_timeout: 超时时间,默认为 None。
| download_pool: 请求最大并发数,默认为 None。
| max_buffer_size: 最大下载大小,默认为 None。
|
| fetch(self, url, method='GET', headers=None, data=None, from_kodo=False, retry=0, max_retry=-1, callback=None, cb_args=None, cb_kargs=None)
| 发起网络请求
| Args:
| url: 请求地址。
| method: 请求方法,默认为 GET 方法。
| headers: 请求的 Headers,默认为 None。
| data: 请求的 body,默认为 None。注意 method 为 GET 时,不支持 data 参数。
| from_kodo: 是否从七牛源站下载,默认为 False
| retry: 当前重试的次数,默认为 0。max_retry 为负时无效。
| max_retry: 允许的最多重试次数,默认为 -1。为负时表示不重试,为 0 时表示使用默认最大重试次数。
| callback: 请求结束后的回调,默认为 None。只在异步 Fetcher 实例下有效。
| cb_args: 传入 callback 的参数。
| cb_kargs: 传入 callback 的参数。
|
| Returns:
| (res, err): 请求的响应对象(tornado.httpclient.HTTPResponse)
| 和请求错误对象(tornado.httpclient.HTTPError)。
| 只有同步 Fetcher 实例有返回。
|
| parse_qiniu_url(self, path)
| convert the qiniu://zone/bucket/key to http://domain/key
| :param path:path 是/bucket/key
| :return: normal uri
|
| start_tornado_ioloop(self)
|
| stop_tornado_ioloop(self)
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| DEFAULT_CONNECT_TIMEOUT = 20
|
| DEFAULT_MAX_BUFFER_SIZE = 1073741824
|
| DEFAULT_MAX_RETRY = 5
|
| DEFAULT_REQUEST_TIMEOUT = 600
|
| DOWNLOAD_POOL_SIZE = 75
class QMacAuthFetcher(__builtin__.object)
| 七牛带鉴权服务请求的封装
|
| Methods defined here:
|
| __init__(self, access_key, secret_key)
| Args:
| access_key: 用户的access_key
| secret_key: 用户的secret_key
|
| fetch(self, url, method='GET', data=None, headers=None, retry=0, max_retry=-1)
| Args:
| url: 请求地址
| method: 请求方法
| data: 请求数据
| headers: 请求头部
| retry: 当前重试次数,默认为0
| max_retry: 最大重试次数,默认为-1表示不重试。为0时表示使用此Fetcher的默认重试次数。
|
| Returns:
| (res, err) 请求的响应对象(tornado.httpclient.HTTPResponse)
| 和请求错误对象(tornado.httpclient.HTTPError)。
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| DEFAULT_MAX_RETRY = 5
utils.uploader
CLASSES
__builtin__.object
Uploader
class Uploader(__builtin__.object)
| 将文件上报至Kodo
|
| Methods defined here:
|
| __init__(self, ak, sk, bucket=None)
|
| set_bucket(self, bucket)
|
| upload_file(self, filepath, key, bucket=None)
| 上报单个文件至Kodo bucket中
| Args:
| filepath: 文件路径
| key: bucket key
| bucket: bucket名称,不提供的话,使用default_bucket
|
| Returns:
| return None when it's ok, or raise exception
|
| Raises:
| Exception, 当上报文件时发生错误,则raise Exception
文档反馈
(如有产品使用问题,请 提交工单)