fightclub

一个模块化的命令行解析器实现方案

问题背景

大约在上个月,有这样一个需求,需要为HPC的用户提供一个统一的任务提交脚本。用户可以使用这个脚本提交任务计算。我们实际使用的任务调度系统是PBSpro。因此这个主要的功能是:

  • 为各种不同的计算软件提供一个一致的入口,对复杂的pbs提交参数进行检测
  • 在任务提交前对用户参数进行校验和检查
  • 在任务提交前根据当前队列的状态,优化提交参数

这是一个简单的需求。我个人比较喜欢argparse这个库。能够很容易实现各种模式的命令行参数。当然也有一些同学喜欢click, Fire或者是optparse,今天就不多做讨论了。

最初的版本大概是这样子的

pbs_sub.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
from glob import glob
from argparse import ArgumentParser
from copy import deepcopy
import logging
import json
import random
import pwd
import sh
__version__ = '1.1.0'
PBS_SERVER = 'p-shhq-hpc-pbs-m01'
QSUB = sh.Command('/opt/pbs/bin/qsub')
PBSNODES = sh.Command('/opt/pbs/bin/pbsnodes')
powerflow_run = '/home/hpcsw/PBSsubmit/run.powerflow2'
dis_run = '/home/hpcsw/PBSsubmit/run.dis2'
logger = logging.getLogger('pbs_sub')
# 一些公共函数
def free_cores(queue):
"""
free cores in queue
"""
count = 0
logger.debug(f'getting free cores in {queue}')
try:
raw_nodedata = json.loads(PBSNODES('-aS', '-F', 'json').stdout)['nodes']
raw_nodedata2 = json.loads(PBSNODES('-Saj', '-F', 'json').stdout)['nodes']
except:
logger.error('PBS error, can not get pbs info')
exit(-1)
for k, v in raw_nodedata.items():
if v.get('queue') != queue:
continue
else:
core_free = int(raw_nodedata2[k].get('ncpus f/t').split('/')[0])
count += core_free
logger.debug(f'number of free cores in {queue} is {count}')
return count
# 公共命令行参数
parser = ArgumentParser(description=__doc__,
epilog='if you have any question or suggestion, please contact with thuhak.zhou@nio.com')
parser.add_argument('-n', '--name', help='job name')
parser.add_argument('-l', '--log_level', choices=['error', 'info', 'debug'], default='info', help='logging level')
parser.add_argument('-W', '--wait', metavar='JOB_ID', help='depend on which job')
software = parser.add_subparsers(dest='software', help='supported software')
# powerflow 软件的命令行参数
powerflow = software.add_parser('powerflow', help='powerflow')
powerflow.add_argument('-q', '--queue', choices=["cfd", "cfd2", "cfdbs"], required=True, help='queue of job')
powerflow.add_argument('-o', '--outdir', help='output directory, default is as same as jobfile')
powerflow.add_argument('-c', '--core', type=int, choices=[32, 64, 128, 256, 512, 768, 1024],
help='how many cpu cores you wanna use. default value is half of the free cores in your queue')
powerflow.add_argument('jobfile', help='job file')
powerflow.add_argument('-v', '--version', choices=["4.4d", "5.3b", "5.3c", "5.4b", "5.5a", "5.5c", "5.5c2", "6-2019"],
default="5.5b", help='version of powerflow software')
powerflow.add_argument('--nts', metavar='TIMESTEPS', type=int, help='num timesteps')
pow_group = powerflow.add_mutually_exclusive_group()
powerflow.add_argument('--seed_bondaries', action='store_true', help='set seed_bondaries')
pow_group.add_argument('-r', '--resume', metavar='RESUME_FILE', help='resume file')
pow_group.add_argument('-s', '--seed', metavar='SEED_FILE', help='seed file')
powerflow.add_argument('--mme', action='store_true', help='mme checkpoint at end')
powerflow.add_argument('--full', action='store_true', help='full checkpoint at end')
powerflow.add_argument('--pt', action='store_true', help='ptherm nprocs and ptherm max unmatched ratio')
powerflow.add_argument('--dis', choices=['only', 'full'],
help='run discretize, when you set only, just run dis job, if your set full, powerflow job will be run after dis job is finish')
powerflow.add_argument('--vr', metavar='LEVEL', type=int, help='suppress vr level for discretize job')
# 其他软件参数位置
other = software.add_parser('other')
...
if __name__ == '__main__':
# 公共的
uid = os.getuid()
user = pwd.getpwuid(uid)[0]
email = user + '@nio.com'
args = parser.parse_args()
soft = args.software
waitfor = args.wait
log_level = getattr(logging, args.log_level.upper())
logging.basicConfig(level=log_level, format='%(asctime)s [%(levelname)s]: %(message)s',
datefmt="%Y-%m-%d %H:%M:%S")
for handler in logging.root.handlers:
handler.addFilter(logging.Filter('pbs_sub'))
base_args = ['-m', 'abe', '-M', email]
if waitfor:
if '.' not in waitfor:
waitfor = waitfor + '.' + PBS_SERVER
var_w = f'depend=afterok:{waitfor}'
base_args.extend(['-W', var_w])
# powerflow 处理参数的处理过程
if soft == 'powerflow':
...
# 其他软件处理过程
elif soft == 'other':
...

整个脚本的大致结构为:

  1. 库和全部变量
  2. 一些公共的辅助函数,以及某些软件所使用的函数
  3. 公共的命令行解析器
  4. 所有软件的命令行解析器作为公共命令行解析器的子解析器
  5. 公共的命令行处理过程
  6. 用if-else放置不同软件的处理函数

这个版本执行起来是没有问题的。但是依然会有几个问题:

  • 如果需要支持的软件很多,那么整个代码会变得非常长
  • 这个脚本需要经常改动。某个软件改动如果搞错可能会影响到不相干的软件执行

为了解决这个问题,有两个实现方案:

  • 实现一个公共库,把公共解析器以及公共函数放在里面,每个软件用一个单独的入口脚本,通过引用公共库来实现代码复用
  • 实现一个公共库,把公共解析器以及公共函数放在里面。再将各个软件当作插件装载进来。这个方案需要解决的问题打乱插件的代码顺序,先执行所有parser的注册,最后根据用户输入参数的解析,将请求路由到恰当的插件里

今天要讲的,就是第二种方案的一种实现

__init_subclass__魔术方法介绍

在这个方案中,使用了__initsubclass__这个实现子类的注册和回调。因此我们首先需要了解一下\_init_subclass__的作用和使用方法。

__init_subclass__是python3.6引入的一个新的魔术方法。使用这个魔术方法,可以在子类定义的过程中作为钩子被调用,作用机制和元类相似。最大的区别就是这个魔术方法只在子类构建的时候被调用,父类是不会调用的。实际用起来更为简洁,编程体验也更好。元类说到底,还是有些麻烦的。而且更为重要的是,这个魔术方法可以和元类一起使用,不会产生冲突。经常使用元类编程的同学遇到的问题可能就是无法同时使用两种元类,有了这个魔术方法,代码设计上又增加了很大的灵活性。

现在我们举个具体的例子来说明__init_subclass__是如何使用的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class MetaData(type):
def __init__(cls, clsname, bases, clsdict):
super().__init__(clsname, bases, clsdict) # 可以操作cls
print('meta class')
class A(metaclass=MetaData):
data = 'A'
def __new__(cls, *args, **kwargs):
print('running new')
return super().__new__(cls, *args, **kwargs) #返回实例对象
def __init_subclass__(cls, *args, **kwargs):
print('running init_subclass') #可以操作cls
def __init__(self, *args, **kwargs):
print('running init')

我们先定义一个元类,以及一个使用该元类的类A,输出结果为

1
meta class

从输出的结果可以看到,在定义类A的时候,MetaData被触发

1
2
class B(A):
...

我们用A作为父类创建一个子类B,输出结果为:

1
2
running init_subclass
meta class

从输出的结果来看,调用的顺序是先调用__init_subclass__,再调用元类。

我们再创建一个实例看看:

1
b = B()

输出结果为:

1
2
running new
running init

这个表现符合预期,没什么多说的。我们再用B为父类创建一个孙子类来看看__init_subclass__是否会执行

1
2
class C(B):
...

输出结果为

1
2
running init_subclass
meta class

可以看到,执行结果和B一样,说明__init_subclass__是可以传递的。我们再来试试方法重写

1
2
3
4
class D(A):
def __init_subclass__(cls, *args, **kwargs):
print('rewrite')
super().__init_subclass__(cls, *args, **kwargs)
1
2
running init_subclass
meta class
1
2
class F(D):
...

执行结果为:

1
2
3
rewrite
running init_subclass
meta class

可以看到,__init_subclass__可以像常规的的方法一样实现重写。我们再来试试多重继承:

1
2
3
4
class E:
def __init_subclass__(cls, *args, **kwargs):
print('other init subclass')
super().__init_subclass__(cls, *args, **kwargs)
1
2
class G(A, E):
...

执行结果为:

1
2
running init_subclass
meta class

E中的__init_subclass__并没有执行,但是并没有有冲突错误。

1
2
class H(E, A):
...

执行结果为

1
2
3
other init subclass
running init_subclass
meta class

因为使用了supper(),A中的__init_subclass__得到了执行。看起来,多重继承的表现和普通的方法没什么两样。

经过一系列实验,最后得到的结论是在同时使用metaclass、__initsubclass__, \_new与__init的时候。在类的构建期依次调用__initsubclass__,与metaclass,而在实例构建的时候,依次调用\_new与__init

而__init_subclass__的特性,和普通的方法区别不大。如果想有更多的了解,可以进一步参考PEP 487

利用__init_subclass__实现注册和回调机制

说完了__init_subclass__的用法以后,我们再具体看一下在这个实际案例中,如何利用这个魔术方法的特性,实现代码的解藕。

先上代码:

pbs_sub.py

1
2
3
4
5
6
#!/usr/bin/env python3.6
from software import MainParser
if __name__ == '__main__':
parser = MainParser()
parser.run()

software/__init__.py

1
2
from .mainparser import MainParser
from .powerflow import PowerFlow

software/mainparser.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
main parser
"""
# author: thuhak.zhou@nio.com
from argparse import ArgumentParser
from weakref import WeakValueDictionary
import logging
from logging.handlers import RotatingFileHandler
import json
import os
from os.path import dirname, join, abspath, isfile, isdir
from string import ascii_letters, digits
import pwd
import time
import sh
__version__ = '2.1.0'
__author__ = 'thuhak.zhou@nio.com'
class classproperty:
def __init__(self, func):
self.func = func
def __get__(self, instance, cls):
if cls is None:
return self
else:
value = self.func(cls)
setattr(cls, self.func.__name__, value)
return value
class MainParser:
"""
provide common args for pbs.
argument parser of subclasses will regist in _all_software variable
the request will be route to the right software subclass
"""
_all_software = WeakValueDictionary()
logger = logging.getLogger('pbs_sub')
PBS_SERVER = 'p-shhq-hpc-pbs-m01'
QSUB = sh.Command('/opt/pbs/bin/qsub')
script_base = join(dirname(abspath(__file__)), 'run_scripts')
def __init__(self):
self.base_args = []
self.parser = ArgumentParser(description=f"This script is used for pbs job submission, version: {__version__}",
epilog=f'if you have any question or suggestion, please contact with {__author__}')
self.parser.add_argument('-n', '--name', help='job name')
self.parser.add_argument('-l', '--log_level', choices=['error', 'info', 'debug'], default='info',
help='logging level')
self.parser.add_argument('-W', '--wait', metavar='JOB_ID', help='depend on which job')
self.parser.add_argument('--free_cores', action='store_true', help='show free cpu cores by queue')
software = self.parser.add_subparsers(dest='software', help='software list')
for soft in self._all_software:
cls = self._all_software[soft]
parser = software.add_parser(cls.__software__, help=f'(script version {cls.__version__})')
cls.add_parser(parser)
@classmethod
def add_parser(cls, parser: ArgumentParser):
"""
abc interface, implement this method in subclass
"""
raise NotImplementedError
@classmethod
def handle(cls, args, base_args) -> list:
"""
abc interface, implement this method in subclass
args: argument args
base_args: all argument provided by main parser
:return all job ids in list
"""
raise NotImplementedError
def __init_subclass__(cls, **kwargs):
"""
regist subclass
"""
super().__init_subclass__(**kwargs)
if not getattr(cls, '__software__') or not getattr(cls, '__version__'):
raise NotImplementedError('you need to set __software__ and __version__ attribute in your subclass')
cls._all_software[cls.__software__] = cls
@classproperty
def default_run(cls):
"""
default run script
"""
run_script = join(cls.script_base, cls.__software__ + '.sh')
if not isfile(run_script):
cls.logger.error(f'you need to put the script {run_script} first')
exit(1)
return run_script
@classproperty
def pbs_nodes_data(cls):
"""
get node data from pbs
"""
cls.logger.debug('getting pbs node info')
try:
PBSNODES = sh.Command('/opt/pbs/bin/pbsnodes')
raw_nodedata = json.loads(PBSNODES('-aS', '-F', 'json').stdout)['nodes']
raw_nodedata2 = json.loads(PBSNODES('-Saj', '-F', 'json').stdout)['nodes']
for k in raw_nodedata.keys():
raw_nodedata[k].update(raw_nodedata2[k])
return raw_nodedata
except Exception as e:
cls.logger.error(f'PBS error, can not get pbs node info, reason: {str(e)}')
exit(-1)
@classproperty
def pbs_job_data(cls):
"""
get pbs job data
"""
cls.logger.debug('getting pbs job info')
try:
QSTAT = sh.Command('/opt/pbs/bin/qstat')
raw_data = sh.grep(QSTAT('-f', '-F', 'json'), '-v', 'Submit_arguments').stdout
job_data = json.loads(raw_data)['Jobs']
return job_data
except Exception as e:
cls.logger.error(f'PBS error, can not get pbs job info, reason: {str(e)}')
exit(-1)
@classmethod
def free_cores(cls, queue: str) -> int:
"""
get free cores in queue
"""
count = 0
cls.logger.debug(f'getting free cores in {queue}')
pbsdata = cls.pbs_nodes_data
for node in pbsdata.values():
if node.get('queue') == queue:
core_free = int(node.get('ncpus f/t').split('/')[0])
count += core_free
cls.logger.debug(f'number of free cores in {queue} is {count}')
return count
@classmethod
def all_free_cores(cls) -> dict:
"""
get all free cores in all queues
"""
from collections import defaultdict
cls.logger.debug('getting free cores')
pbsdata = cls.pbs_nodes_data
ret = defaultdict(int)
for node in pbsdata.values():
queue = node.get('queue')
if queue:
ret[queue] += int(node.get('ncpus f/t').split('/')[0])
return ret
@classmethod
def check_jobid(cls, jid: str) -> bool:
"""
check job id in pbs or not
"""
cls.logger.debug(f'checking job {jid}')
return jid in cls.pbs_job_data
@classmethod
def get_jid_info(cls, jid: str) -> dict:
cls.logger.debug(f'getting job information for {jid}')
try:
QSTAT = sh.Command('/opt/pbs/bin/qstat')
raw_info = sh.grep(QSTAT('-f', '-F', 'json', jid), '-v', 'Submit_arguments').stdout
job_info = json.loads(raw_info)['Jobs'][jid]
return job_info
except json.decoder.JSONDecodeError:
cls.logger.error(f'job {jid} is not json format')
cls.logger.debug(raw_info)
exit(1)
except Exception as e:
cls.logger.error(f'unable to get information from pbs, reason: {str(e)}')
exit(1)
@staticmethod
def replace_id(args: list, jobid: str) -> None:
"""this function will change the state of args"""
try:
w_i = args.index('-W')
args[w_i + 1] = f'depend=afterok:{jobid}'
except ValueError:
args.extend(['-W', f'depend=afterok:{jobid}'])
@staticmethod
def fix_jobname(name: str) -> str:
valid_str = ascii_letters + digits + '_'
ret = name
for c in name:
if c not in valid_str:
ret = ret.replace(c, '_')
return ret
@staticmethod
def get_ncpu(queue: str) -> int:
return 16 if queue == 'cfdbs' else 32
def run(self):
"""
program entry
"""
uid = os.getuid()
user = pwd.getpwuid(uid)[0]
args = self.parser.parse_args()
log_level = getattr(logging, args.log_level.upper())
handlers = [logging.StreamHandler()]
logdir = '/var/log/pbs_sub'
if isdir(logdir):
loghandler = RotatingFileHandler(join(logdir, user + '.log'), maxBytes=10*1024*1024, backupCount=3, encoding='utf-8')
handlers.append(loghandler)
logging.basicConfig(level=log_level, format='%(asctime)s [%(levelname)s]: %(message)s',
datefmt="%Y-%m-%d %H:%M:%S", handlers=handlers)
for handler in logging.root.handlers:
handler.addFilter(logging.Filter('pbs_sub'))
if args.free_cores:
free_cores = self.all_free_cores()
ali = max([len(x) for x in free_cores.keys()]) + 3
for q, c in free_cores.items():
space = ' ' * (ali - len(q))
print(f'{q}:{space}{c}')
return
software = args.software
email = user + '@nio.com'
waitfor = args.wait
base_args = ['-m', 'abe', '-M', email]
if waitfor:
if '.' not in waitfor:
waitfor = waitfor + '.' + self.PBS_SERVER
if not self.check_jobid(waitfor):
self.logger.debug(f'invalid job id {waitfor}')
exit(1)
var_w = f'depend=afterok:{waitfor}'
base_args.extend(['-W', var_w])
if software in self._all_software:
jids = self._all_software[software].handle(args, base_args)
time.sleep(1)
for jid in jids:
job_info = self.get_jid_info(jid)
job_stat = job_info.get('job_state')
info = f'job {jid} is in state {job_stat}'
comment = job_info.get('comment')
if comment:
info += f',comment: {comment}'
self.logger.info(info)
else:
self.parser.print_help()

software/powerflow.py

因为涉及到具体的业务逻辑,因此整个实现就被我省略了,这里只是为了演示整个结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# author: thuhak.zhou@nio.com
import os
import random
from copy import deepcopy
from glob import glob
import re
import sh
from .mainparser import MainParser
class PowerFlow(MainParser):
__software__ = 'powerflow'
__version__ = '2.1.0'
base_dir = '/home/hpcsw/EXA/powerflow'
@classmethod
def add_parser(cls, parser):
parser.add_argument('-q', '--queue', choices=["cfd", "cfd2", "cfdbs"], default="cfdbs", help='queue of job')
parser.add_argument('-c', '--core', type=int, choices=[32, 64, 128, 256, 512, 768, 1024],
help='how many cpu cores you wanna use. default value is half of the free cores in your queue')
parser.add_argument('jobfile', help='job file')
parser.add_argument('-v', '--version',
choices=["4.4d", "5.3b", "5.3c", "5.4b", "5.5a", "5.5b", "5.5c", "5.5c2", "6-2019"],
default="5.5b", help='version of powerflow')
parser.add_argument('--nts', metavar='TIMESTEPS', type=int, help='num timesteps')
parser.add_argument('--seed_bondaries', action='store_true', help='set seed_bondaries')
pow_group = parser.add_mutually_exclusive_group()
pow_group.add_argument('-r', '--resume', metavar='RESUME_FILE', help='resume file')
pow_group.add_argument('-s', '--seed', metavar='SEED_FILE', help='seed file')
parser.add_argument('--mme', action='store_true', help='mme checkpoint at end')
parser.add_argument('--full', action='store_true', help='full checkpoint at end')
parser.add_argument('--pt', action='store_true', help='enable powerTHERM')
parser.add_argument('--dis', choices=['only', 'full'],
help='run discretize, when you set only, just run dis job, if your set full, powerflow job will be run after dis job is finish')
parser.add_argument('--vr', metavar='LEVEL', type=int, help='suppress vr level for discretize job')
parser.add_argument('--postacoustics', choices=['only', 'full'],
help='run powerflow post process, contact with fuchao.wang@nio.com to get more information')
@classmethod
def default_cores(cls, queue):
total_cores = cls.free_cores(queue)
for c in (2048, 1536, 1024, 960, 512):
if total_cores >= c:
return c // 2
return 128
@classmethod
def handle(cls, args, base_args):
jids = []
...
return jids

从整个代码结构上, 我们可以看到。

公共的逻辑被放置到了MainParser中,而powerflow的业务逻辑,则被独立开来,放置到了一个单独的文件内。并通过类的继承,复用了公用的类的方法,变量,以及数据缓存。

整个软件的逻辑如下:

  1. 在MainParser类中,通过__init_subclass__,在类还在创建过程中,就将整个类对象放置到自己的全局字典中。而类创建代码,则是在import的时候执行
  2. 在MainPasrer进行实例初始化的时候,依次调用所有子类的add_parser方法,实现所有插件parser的注册
  3. 最后在实例化的mainparser对象中,通过run方法,解析用户的输入,并路由给对应的插件handle处理方法。拿到对应handle的处理后,在进行下一道工序的处理

所有的MainParser的类中,所有的子类均没有进行实例化,因为并没有必要。唯一实例化的是MainParser本身。

而当你需要添加一个新的计算软件。那么你只需要:

在software中放置一个单独的文件,写一个MainParser的子类.

在这个子类中,需要实现:

  • 定义__software,以及__version类变量。注册插件信息
  • 实现add_parser类方法,提供子命令的入口给mainparser回调
  • 实现handle类方法,提供给mainparser回调
  • 最后通过修改sotware/__init__.py将新的代码引入即可

如果觉得每次都要import一下新的代码很烦,可以定制import的过程做自动加载,但是目前来看,并没有多少必要来做这个事,还是从简的好。

总结

通过使用python元编程,我们往往可以使用极少的代码就可以实现比较复杂的设计模式。这也python是让人非常上瘾的一点。