mmselfsup call logic
2022-09-09 16:42:34
open-mmlab series: the self-supervised learning call logic of mmselfsup
dataloader
IterBaseRunner
+ register_training_hooks()
+ run() + train() + val() + resume() + save_checkpoint()
optimizer
Config parameters config (dict): model, dataset, schedules, runtime
utils/register.py
Traditional training workflow
openmmlab-mmselfsup
Decorator
BaseAlgorithm
+ train_step() # calls forward()
- forward() # calls forward_train()
OPTIMIZER_BUILDERS = Registry('optimizer builder')
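The training process is abstracted into the lower-level mmcv library so that it fits most algorithms; each specific algorithm is implemented in its concrete model class.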
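The split between a generic trainer and an algorithm-specific model can be sketched as follows. This is a simplified, dependency-free stand-in, not mmselfsup's BaseAlgorithm: `BaseAlgorithmSketch` and `ToyAlgorithm` are hypothetical, losses are plain floats instead of tensors, and the `mode` argument plus the dict returned by `train_step()` are assumptions about the general shape of the real code.

```python
from abc import ABC, abstractmethod


class BaseAlgorithmSketch(ABC):
    """Hypothetical stand-in for mmselfsup's BaseAlgorithm (simplified)."""

    def train_step(self, data, optimizer=None):
        # The runner only calls train_step(); it never needs to know which algorithm runs.
        losses = self.forward(mode='train', **data)
        loss, log_vars = self._parse_losses(losses)
        return dict(loss=loss, log_vars=log_vars, num_samples=len(data['img']))

    def forward(self, mode='train', **kwargs):
        # forward() only dispatches on the mode; the real work lives in subclasses.
        if mode == 'train':
            return self.forward_train(**kwargs)
        if mode == 'test':
            return self.forward_test(**kwargs)
        raise ValueError(f'unknown mode: {mode}')

    @abstractmethod
    def forward_train(self, img, **kwargs):
        """Each concrete algorithm returns a dict of named losses."""

    def forward_test(self, img, **kwargs):
        raise NotImplementedError

    @staticmethod
    def _parse_losses(losses):
        # Sum every entry whose key contains 'loss'; keep all entries for logging.
        log_vars = {name: float(value) for name, value in losses.items()}
        loss = sum(v for k, v in log_vars.items() if 'loss' in k)
        log_vars['loss'] = loss
        return loss, log_vars


class ToyAlgorithm(BaseAlgorithmSketch):
    """Hypothetical algorithm whose 'loss' is simply the mean of its inputs."""

    def forward_train(self, img, **kwargs):
        return dict(loss_toy=sum(img) / len(img))
```

Calling `ToyAlgorithm().train_step(dict(img=[1.0, 2.0, 3.0]))` returns a dict whose `loss` is 2.0. A new algorithm only has to implement `forward_train()`; the runner-facing `train_step()` stays unchanged.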
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
DATASOURCE = Registry('datasource')
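The three registries suggest that a dataset is assembled from a data source (where raw samples come from) and a pipeline (how each sample is transformed). A minimal sketch of that composition follows, assuming a two-view output as contrastive methods need. `ListDataSource`, `Compose`, `TwoViewDataset` and the `get_img()` method are hypothetical names for illustration, not mmselfsup's classes.

```python
class ListDataSource:
    """Hypothetical data source: returns raw samples by index."""

    def __init__(self, samples):
        self.samples = list(samples)

    def __len__(self):
        return len(self.samples)

    def get_img(self, idx):
        return self.samples[idx]


class Compose:
    """Applies a list of transforms in order, like a pipeline built from PIPELINES."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


class TwoViewDataset:
    """Hypothetical dataset: runs the pipeline twice on the same raw sample.

    With random augmentations the two views differ, which is what contrastive
    algorithms such as MoCo compare against each other.
    """

    def __init__(self, data_source, pipeline):
        self.data_source = data_source
        self.pipeline = Compose(pipeline)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        raw = self.data_source.get_img(idx)
        return dict(img=[self.pipeline(raw), self.pipeline(raw)])
```

Keeping the source and the transforms separate means the same images can be reused with a different augmentation pipeline by changing only the config.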
model
Build the model: build_algorithm()
runner
Runner
Config
+ fromfile(filename):return cfg
EpochBaseRunner
+ train() + run() + run_iter() # calls the model's train_step() to compute the outputs
+ val()+ save_checkpoint()
while (epoch < max_epochs)
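The epoch loop and the points where hooks fire can be sketched as below. This is a simplified re-implementation under the assumption that the stage order matches mmcv's EpochBaseRunner (before_run, then per epoch before/after_train_epoch wrapping per-iteration before/after_train_iter, then after_run); `EpochRunnerSketch` is hypothetical and omits workflows, validation, logging and checkpoint resumption.

```python
class EpochRunnerSketch:
    """Hypothetical, simplified epoch-based runner (not mmcv's EpochBaseRunner)."""

    def __init__(self, model, data_loader, max_epochs, hooks=()):
        self.model = model
        self.data_loader = data_loader
        self._max_epochs = max_epochs
        self._hooks = list(hooks)  # assumed to be already sorted by priority
        self.epoch = 0
        self.iter = 0
        self.outputs = None

    def call_hook(self, fn_name):
        # Every hook is offered every stage; hooks without that method are skipped.
        for hook in self._hooks:
            getattr(hook, fn_name, lambda runner: None)(self)

    def run_iter(self, data_batch):
        # The runner delegates the computation to the model's train_step().
        self.outputs = self.model.train_step(data_batch, None)

    def train(self):
        self.call_hook('before_train_epoch')
        for data_batch in self.data_loader:
            self.call_hook('before_train_iter')
            self.run_iter(data_batch)
            self.call_hook('after_train_iter')
            self.iter += 1
        self.call_hook('after_train_epoch')
        self.epoch += 1  # incremented after the hooks, so they see the 0-based epoch

    def run(self):
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            self.train()
        self.call_hook('after_run')
```

Because optimization, learning-rate updates and checkpointing all happen inside hooks, the loop itself stays algorithm-agnostic.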
dataset
datasets
Build the datasets: build_dataset()
runner/optimizer/build.py
dataloader=build_dataloader(dataset)
model.train_step()
Hook
+ before_run(runner)+ before_train_epoch(runner)+ before_train_iter(runner)+ after_train_iter(runner)+ after_train_epoch(runner)+ before_val_epoch(runner)+ before_val_iter(runner)+ after_val_iter(runner)+ after_val_epoch(runner)+ after_run(runner)+ get_triggered_stages()
runner/hooks
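Registers the hooks to be executed with the runner configured in cfg (e.g. "EpochBaseRunner"). The hooks end up in the _hooks list ordered by priority, and run() iterates over _hooks and calls each one.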
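Priority-ordered registration can be sketched with a sorted insert. The numeric priority values below are my understanding of mmcv's Priority enum and should be treated as an assumption; `RunnerHooksSketch` is a hypothetical class. Smaller numbers run earlier, and hooks with equal priority keep their registration order.

```python
import bisect

# Assumed to mirror mmcv's Priority enum: smaller value = higher priority = called earlier.
PRIORITY = {'HIGHEST': 0, 'VERY_HIGH': 10, 'HIGH': 30, 'ABOVE_NORMAL': 40,
            'NORMAL': 50, 'BELOW_NORMAL': 60, 'LOW': 70, 'VERY_LOW': 90,
            'LOWEST': 100}


class RunnerHooksSketch:
    """Hypothetical runner fragment that keeps _hooks sorted by priority."""

    def __init__(self):
        self._hooks = []

    def register_hook(self, hook, priority='NORMAL'):
        value = PRIORITY[priority] if isinstance(priority, str) else int(priority)
        hook.priority = value
        # bisect_right inserts after existing hooks of equal priority,
        # so equal-priority hooks run in registration order.
        position = bisect.bisect_right([h.priority for h in self._hooks], value)
        self._hooks.insert(position, hook)

    def call_hook(self, fn_name):
        # Walk the sorted list; hooks that do not define this stage are skipped.
        for hook in self._hooks:
            stage = getattr(hook, fn_name, None)
            if stage is not None:
                stage(self)
```

Sorting once at registration time keeps call_hook() a plain loop over _hooks, which is the part executed on every iteration.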
cnn
base.py
CustomHook
+ custom_function(runner)
train.py
runner.register_hook()
RUNNER_BUILDERS.build(cfg)
Passes the registered Registry module to build_from_cfg() to create the required object.
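The register-then-build pattern can be sketched as follows, assuming a config is a dict whose `type` key names a registered class and whose remaining keys are constructor arguments. `MiniRegistry`, `mini_build_from_cfg`, `TinyBackbone` and `TinyAlgorithm` are hypothetical stand-ins, not mmcv's implementation; the nested build mimics how an algorithm builds its backbone from a sub-config.

```python
class MiniRegistry:
    """Hypothetical registry: maps a string name to a class."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None):
        # Used as a decorator, e.g. @BACKBONES.register_module()
        def _register(cls):
            key = name or cls.__name__
            if key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict[key]

    def build(self, cfg, **default_args):
        return mini_build_from_cfg(cfg, self, default_args)


def mini_build_from_cfg(cfg, registry, default_args=None):
    args = dict(cfg)  # copy so the caller's config is not mutated
    for key, value in (default_args or {}).items():
        args.setdefault(key, value)  # explicit config values win over defaults
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type) if isinstance(obj_type, str) else obj_type
    return obj_cls(**args)


BACKBONES = MiniRegistry('backbone')
ALGORITHMS = MiniRegistry('algorithm')


@BACKBONES.register_module()
class TinyBackbone:  # hypothetical
    def __init__(self, depth):
        self.depth = depth


@ALGORITHMS.register_module()
class TinyAlgorithm:  # hypothetical
    def __init__(self, backbone, momentum=0.99):
        # Sub-configs are built by the parent, like build_backbone() inside an algorithm.
        self.backbone = BACKBONES.build(backbone)
        self.momentum = momentum
```

For example, `ALGORITHMS.build(dict(type='TinyAlgorithm', backbone=dict(type='TinyBackbone', depth=18)))` builds the algorithm and its backbone from plain dicts, which is what lets a config file choose components by name.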
cfg = Config.fromfile(filename)
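What fromfile() does at its core can be sketched by executing a Python config file and keeping its top-level variables. This is a simplified, hypothetical `fromfile_sketch`, not mmcv's Config: it assumes a plain Python config and ignores `_base_` inheritance, config merging and non-Python formats.

```python
import os
import runpy
import tempfile
import types


class ConfigDict(dict):
    """dict with attribute access, e.g. cfg.model.backbone (simplified)."""

    def __getattr__(self, name):
        try:
            value = self[name]
        except KeyError as e:
            raise AttributeError(name) from e
        return ConfigDict(value) if isinstance(value, dict) else value


def fromfile_sketch(filename):
    # Execute the config file and keep its public top-level variables,
    # dropping dunder names and any modules the file imported.
    namespace = runpy.run_path(filename)
    return ConfigDict({k: v for k, v in namespace.items()
                       if not k.startswith('__') and not isinstance(v, types.ModuleType)})


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'demo_cfg.py')
        with open(path, 'w') as f:
            f.write("model = dict(type='MoCo', backbone=dict(depth=50))\ntotal_epochs = 200\n")
        cfg = fromfile_sketch(path)
        print(cfg.model.backbone.depth, cfg.total_epochs)
```

Using an executable Python file as the config format is what allows model, dataset, schedule and runtime settings to be plain dicts that the registries can consume directly.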
cnn.bricks.registry
OPTIMIZERS = Registry('optimizer')
RUNNER = Registry('runner')
runner.register_training_hooks()
train()
train_model流程
MODELS
+ build_algorithm(cfg)+ build_backbone(cfg)+ build_head(cfg)+ build_neck(cfg)+ build_memory(cfg)
EvalHook
+ evaluation()
+ before_run(runner)+ before_train_epoch(runner)+ before_train_iter(runner)+ after_train_iter(runner)+ after_train_epoch(runner)
runner/build.py
run()
OptimizerHook
+ after_train_iter(runner)
mmcv
dataset=build_dataset(cfg.data)
CheckpointHook
+ before_run(runner)
+ after_train_iter(runner)+ after_train_epoch(runner)
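The epoch-level branch of a checkpoint hook can be sketched as an interval check in after_train_epoch(). This is a hypothetical `CheckpointHookSketch`, not mmcv's CheckpointHook: it pickles a hypothetical `runner.state` attribute instead of saving model weights with torch, and it ignores iteration-based saving and checkpoint rotation.

```python
import os
import pickle
import tempfile


class CheckpointHookSketch:
    """Hypothetical hook that saves runner state every `interval` epochs."""

    def __init__(self, out_dir, interval=1):
        self.out_dir = out_dir
        self.interval = interval
        self.saved = []  # paths written so far

    def before_run(self, runner):
        os.makedirs(self.out_dir, exist_ok=True)

    def after_train_epoch(self, runner):
        # runner.epoch is still 0-based here, so the finished epoch is epoch + 1.
        finished = runner.epoch + 1
        if finished % self.interval != 0:
            return
        path = os.path.join(self.out_dir, f'epoch_{finished}.pkl')
        with open(path, 'wb') as f:
            pickle.dump({'epoch': finished, 'state': runner.state}, f)
        self.saved.append(path)


if __name__ == '__main__':
    class _Runner:  # hypothetical stand-in for a runner
        epoch = 0
        state = {'w': 1.0}

    with tempfile.TemporaryDirectory() as tmp:
        hook = CheckpointHookSketch(os.path.join(tmp, 'work_dir'), interval=2)
        runner = _Runner()
        hook.before_run(runner)
        for epoch in range(4):
            runner.epoch = epoch
            hook.after_train_epoch(runner)
        print([os.path.basename(p) for p in hook.saved])
```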
config
MOCO(BaseAlgorithm)
+ forward_train() # invoked through forward()
- forward()
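The two ingredients of MoCo's forward_train() can be sketched with plain lists: the momentum (EMA) update of the key encoder, theta_k = m * theta_k + (1 - m) * theta_q, and an InfoNCE loss in which the positive key is class 0 and queued keys are negatives. This is an illustration of the MoCo algorithm, not mmselfsup's code; parameters are flat float lists, and `temperature=0.2` is just a default chosen for the sketch.

```python
import math
from collections import deque


def momentum_update(key_params, query_params, m=0.999):
    """theta_k <- m * theta_k + (1 - m) * theta_q, element-wise; the key encoder gets no gradients."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def info_nce_loss(q, k, queue, temperature=0.2):
    """Cross-entropy with the positive logit q.k as class 0 and q.n for each queued key n."""
    logits = [dot(q, k) / temperature] + [dot(q, n) / temperature for n in queue]
    max_logit = max(logits)  # subtract the max before exp() for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    return log_sum_exp - logits[0]


def enqueue(queue, keys):
    """Push the newest keys; a deque with maxlen drops the oldest ones (FIFO)."""
    queue.extend(keys)
    return queue
```

A `deque(maxlen=K)` gives the fixed-size first-in-first-out behaviour of MoCo's negative-key queue without manual index bookkeeping.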
RUNNER_BUILDERS = Registry('runner builder')
Registry
+ build() # builds a module (interface)
+ build_func() # build function (external function)
+ register_module() # registers a module
LrUpdaterHook
+ get_lr(runner)+ get_regular_lr(runner)+ get_warmup_lr(runner)
+ before_run(runner)+ after_train_iter(runner)+ after_train_epoch(runner)
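The roles of get_regular_lr() and get_warmup_lr() can be sketched as two formulas: a regular schedule (cosine annealing here) and a linear warmup that scales it during the first iterations. The formulas follow my understanding of mmcv's CosineAnnealing policy and 'linear' warmup and should be treated as assumptions; all function names are hypothetical.

```python
import math


def cosine_lr(base_lr, progress, max_progress, min_lr=0.0):
    """Regular LR: cosine decay from base_lr (progress=0) to min_lr (progress=max)."""
    factor = progress / max_progress
    return min_lr + 0.5 * (base_lr - min_lr) * (math.cos(math.pi * factor) + 1)


def linear_warmup_lr(regular_lr, cur_iter, warmup_iters, warmup_ratio=0.1):
    """Warmup LR: grows linearly from warmup_ratio * regular_lr to regular_lr."""
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return regular_lr * (1 - k)


def get_lr(base_lr, epoch, max_epochs, cur_iter, warmup_iters=0, min_lr=0.0):
    # Regular LR is computed per epoch; warmup (if any) is applied per iteration.
    regular = cosine_lr(base_lr, epoch, max_epochs, min_lr)
    if warmup_iters and cur_iter < warmup_iters:
        return linear_warmup_lr(regular, cur_iter, warmup_iters)
    return regular
```

For instance, with `base_lr=0.1` the warmup starts at 0.01 on iteration 0 and reaches the regular cosine value once `cur_iter` hits `warmup_iters`.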
mmselfsup/mmselfsup
BaseRunner
+ run() # virtual method, implemented in subclasses
+ train() # virtual method, implemented in subclasses
+ val() # virtual method, implemented in subclasses
+ save_checkpoint()
- _hooks+ register_hook():return _hooks+ call_hook()+ resume()+ register_lr_hook(cfg)+ register_momentum_hook(cfg)+ register_optimizer_hook(cfg)+ register_checkpoint_hook(cfg)+ register_logger_hook(cfg)+ register_timer_hook(cfg)+ register_custom_hook(cfg)+ register_training_hook(cfg)
HOOKS = Registry('hook')
mmselfsup/tools
MMCV_ATTENTION = Registry('attention')
logger
api/train.py