tools: ynl: support directional enum-model in CLI
[linux-2.6-microblaze.git] / tools / net / ynl / lib / ynl.py
1 # SPDX-License-Identifier: BSD-3-Clause
2
3 import functools
4 import os
5 import random
6 import socket
7 import struct
8 import yaml
9
10 from .nlspec import SpecFamily
11
12 #
13 # Generic Netlink code which should really be in some library, but I can't quickly find one.
14 #
15
16
class Netlink:
    """Numeric constants from the netlink, generic netlink and nlctrl UAPI.

    Everything is a plain int so values can be OR-ed together and handed
    straight to ``struct``/``setsockopt``.
    """

    # ---- socket level and socket options --------------------------------
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # ---- nlmsghdr.nlmsg_type control values -----------------------------
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # ---- nlmsghdr.nlmsg_flags used on requests --------------------------
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9
    NLM_F_APPEND = 1 << 11
    # Combination sent for dump requests.
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # ---- nlmsghdr.nlmsg_flags seen on acks (reuse the bits above) -------
    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    # ---- nlattr.nla_type flag bits ---------------------------------------
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14
    # Both flag bits together; clear them to get the bare attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # ---- generic netlink -------------------------------------------------
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl command
    CTRL_CMD_GETFAMILY = 3

    # nlctrl per-family attributes
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # attributes inside each entry of CTRL_ATTR_MCAST_GROUPS
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # ---- extended ACK (extack) attribute types ---------------------------
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
68
69
class NlAttr:
    """One netlink attribute (nlattr header + payload) parsed from a buffer.

    Instance fields:
        _len: raw nla_len field -- 4-byte header plus payload, no padding.
        _type: raw nla_type field, still carrying the NLA_F_* flag bits.
        type: attribute type with the NLA_F_* flag bits cleared.
        payload_len: same value as ``_len`` (it includes the header despite
            the name).
        full_len: ``_len`` rounded up to a multiple of 4, i.e. the distance
            to the next attribute in the buffer.
        raw: payload bytes only (header stripped, padding excluded).

    Raises:
        ValueError: if nla_len is smaller than the 4-byte header.
        struct.error: if fewer than 4 bytes remain at ``offset``.
    """

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # nla_len always counts the 4-byte header, so a smaller value means a
        # corrupt buffer. Zero in particular would give full_len == 0 and make
        # an attribute walk (NlAttrs) spin forever at this offset.
        if self._len < 4:
            raise ValueError(f"Malformed netlink attribute at offset {offset}: "
                             f"nla_len {self._len} is smaller than the header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    def as_u8(self):
        """Payload as a native-endian unsigned 8-bit int."""
        return struct.unpack("B", self.raw)[0]

    def as_u16(self):
        """Payload as a native-endian unsigned 16-bit int."""
        return struct.unpack("H", self.raw)[0]

    def as_u32(self):
        """Payload as a native-endian unsigned 32-bit int."""
        return struct.unpack("I", self.raw)[0]

    def as_u64(self):
        """Payload as a native-endian unsigned 64-bit int."""
        return struct.unpack("Q", self.raw)[0]

    def as_strz(self):
        """Payload decoded as ASCII with the final (terminator) character dropped."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Payload bytes unchanged."""
        return self.raw

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
98
99
class NlAttrs:
    """Every netlink attribute packed back to back in ``msg``.

    Iteration yields the parsed NlAttr objects in buffer order; ``repr``
    lists one attribute per line.
    """

    def __init__(self, msg):
        self.attrs = []

        pos = 0
        end = len(msg)
        while pos < end:
            nla = NlAttr(msg, pos)
            self.attrs.append(nla)
            # full_len includes the alignment padding, so this lands on the
            # next attribute header.
            pos += nla.full_len

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        return '\n'.join(repr(nla) for nla in self.attrs)
120
121
class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) parsed from a buffer.

    Args:
        msg: buffer holding one or more netlink messages.
        offset: byte offset of this message's header inside ``msg``.
        attr_space: optional attribute set; when given, a top-level
            "missing attribute" extack is rewritten from the attribute number
            to its spec name (plus its doc string, if any).

    Instance fields:
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: the header fields.
        raw: payload bytes following the header.
        error: signed error code from an NLMSG_ERROR message, else 0.
        done: 1 for NLMSG_ERROR / NLMSG_DONE messages, else 0.
        extack: dict of decoded extended-ACK attributes, or None.

    Raises:
        ValueError: if nlmsg_len is smaller than the 16-byte header.
    """

    def __init__(self, msg, offset, attr_space=None):
        # Header layout: u32 len, u16 type, u16 flags, u32 seq, u32 portid.
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # nlmsg_len counts the header itself, so anything below 16 is corrupt.
        # A zero length would also leave the NlMsgs walk stuck at this offset.
        if self.nl_len < 16:
            raise ValueError(f"Malformed netlink message at offset {offset}: "
                             f"nlmsg_len {self.nl_len} is smaller than the header")

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Skip the 4-byte error code and the 16-byte echoed request header.
            # Assumes the request payload was not echoed (NETLINK_CAP_ACK),
            # which the sockets in this file enable.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            # DONE carries 4 bytes ahead of any extack attributes.
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_u32()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_u32()
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_u32()
                else:
                    # Kept as raw NlAttr objects for the caller to inspect.
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        # Terminate each optional line so it does not run into the next one
        # (GenlMsg.__repr__ appends its own line after this text).
        if self.error:
            msg += '\terror: ' + str(self.error) + '\n'
        if self.extack:
            msg += '\textack: ' + repr(self.extack) + '\n'
        return msg
179
180
class NlMsgs:
    """Every netlink message packed back to back in ``data`` (one recv()).

    ``attr_space`` is forwarded to each NlMsg for extack decoding.
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        pos = 0
        end = len(data)
        while pos < end:
            nlmsg = NlMsg(data, pos, attr_space=attr_space)
            self.msgs.append(nlmsg)
            # nlmsg_len spans header + payload, i.e. up to the next message.
            pos += nlmsg.nl_len

    def __iter__(self):
        return iter(self.msgs)
193
194
# Module-wide cache of the generic netlink families reported by the kernel,
# keyed by family name. It stays None until the first GenlFamily() call runs
# _genl_load_families(), which replaces it with a dict. Despite the name, each
# value is a whole family dict: 'id' and 'name', plus 'maxattr' and 'mcast'
# (multicast group name -> group id) when the kernel reports them.
genl_family_name_to_id = None
196
197
198 def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
199     # we prepend length in _genl_msg_finalize()
200     if seq is None:
201         seq = random.randint(1, 1024)
202     nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
203     genlmsg = struct.pack("bbH", genl_cmd, genl_version, 0)
204     return nlmsg + genlmsg
205
206
207 def _genl_msg_finalize(msg):
208     return struct.pack("I", len(msg) + 4) + msg
209
210
def _genl_load_families():
    """Fill genl_family_name_to_id from an nlctrl CTRL_CMD_GETFAMILY dump.

    Uses a short-lived NETLINK_GENERIC socket. Each reply that carries both
    a family name and id is stored under its name as a dict with 'id',
    'name' and, when present, 'maxattr' and 'mcast' (group name -> id).
    Returns at NLMSG_DONE. On a netlink error it prints the error code and
    returns early, keeping whatever was collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg(Netlink.GENL_ID_CTRL, dump_flags, Netlink.CTRL_CMD_GETFAMILY, 1)
        sock.send(_genl_msg_finalize(request), 0)

        genl_family_name_to_id = dict()

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = dict()
                for attr in GenlMsg(nl_msg).raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_u16()
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_u32()
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        family['mcast'] = groups = dict()
                        # Each entry is itself a nest holding a name and an id.
                        for group in NlAttrs(attr.raw):
                            grp_name = grp_id = None
                            for field in NlAttrs(group.raw):
                                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    grp_name = field.as_strz()
                                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    grp_id = field.as_u32()
                            if grp_name and grp_id is not None:
                                groups[grp_name] = grp_id

                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
258
259
class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr fields plus attributes.

    Args:
        nl_msg: the NlMsg whose payload starts with a 4-byte genlmsghdr.

    Instance fields:
        nl: the wrapped NlMsg.
        hdr: the 4 genlmsghdr bytes.
        raw: payload after the genlmsghdr.
        genl_cmd, genl_version: unsigned command and version numbers.
        raw_attrs: NlAttrs parsed from ``raw``.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        self.raw = nl_msg.raw[4:]

        # genlmsghdr is u8 cmd, u8 version, u16 reserved. Decode the bytes as
        # unsigned so command ids above 127 don't come back negative and fail
        # to match spec values.
        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg
277
278
class GenlFamily:
    """Kernel-reported identity of one generic netlink family.

    The first instance triggers a dump of all families into the module-level
    genl_family_name_to_id cache; later instances reuse that cache.

    Instance fields:
        family_name: the requested name.
        genl_family: the cached family dict ('id', 'name', maybe 'mcast').
        family_id: the numeric family id.

    Raises:
        KeyError: if the kernel did not report ``family_name``.
    """

    def __init__(self, family_name):
        self.family_name = family_name

        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']
289
290
291 #
292 # YNL implementation details.
293 #
294
295
class YnlFamily(SpecFamily):
    """Netlink client for one generic netlink family described by a YNL spec.

    Every operation in the spec is exposed as a method named after the op's
    ``ident_name``. Calling it sends the request and returns the decoded reply
    (see ``_op``). Notifications seen while waiting for a reply, or drained by
    ``check_ntf()``, are decoded and appended to ``async_msg_queue``.
    """

    # struct formats for the fixed-width integer attribute types.
    _INT_FORMATS = {'u8': 'B', 'u16': 'H', 'u32': 'I', 'u64': 'Q'}

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        # When True, queued notifications also carry the raw NlMsg/GenlMsg.
        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # Spec 'definitions' by name; enum attributes are decoded through it.
        self._types = dict()

        for elem in self.yaml.get('definitions', []):
            self._types[elem['name']] = elem

        # Reply command ids that identify notifications, and decoded
        # notifications waiting to be consumed.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        self.family = GenlFamily(self.yaml['name'])

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name``.

        Raises:
            Exception: if the family has no group of that name, including
                families that advertise no multicast groups at all.
        """
        mcast_groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        # Bind only while the socket has no port id yet. After the first bind,
        # or the implicit bind on the first send(), the kernel rejects a
        # re-bind to port 0 with EINVAL. That would break a second subscription
        # or subscribing after an operation.
        if self.sock.getsockname()[0] == 0:
            self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attribute set ``space`` as padded nlattr bytes.

        A nest takes a dict of sub-attribute name -> value and is encoded
        recursively. A flag encodes as an empty payload whatever ``value`` is.
        Integer types are packed in native byte order.

        Raises:
            Exception: for attribute types this encoder does not handle.
        """
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        attr_type = attr["type"]
        if attr_type == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr_type == 'flag':
            attr_payload = b''
        elif attr_type in self._INT_FORMATS:
            attr_payload = struct.pack(self._INT_FORMATS[attr_type], int(value))
        elif attr_type == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr_type == 'binary':
            attr_payload = value
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # Pad the payload to a 4-byte boundary; nla_len itself excludes the pad.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, rsp, attr_spec):
        """Replace the integer value of an enum attribute in ``rsp`` in place.

        With 'enum-as-flags', bit k of the value selects
        entries[value-start + k] and the result is a set. Otherwise the value
        minus 'value-start' (default 0) indexes the definition's 'entries'.
        """
        raw = rsp[attr_spec['name']]
        enum = self._types[attr_spec['enum']]
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum['entries'][i])
                raw >>= 1
                i += 1
        else:
            value = enum['entries'][raw - i]
        rsp[attr_spec['name']] = value

    def _decode(self, attrs, space):
        """Decode an iterable of NlAttr against attribute set ``space`` into a dict.

        Keys are attribute names from the spec. Nests are decoded recursively,
        and attributes with an 'enum' are converted by _decode_enum().

        Raises:
            Exception: for attribute types this decoder does not handle.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            attr_type = attr_spec['type']
            if attr_type == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                rsp[attr_spec['name']] = subdict
            elif attr_type == 'u8':
                rsp[attr_spec['name']] = attr.as_u8()
            elif attr_type == 'u16':
                rsp[attr_spec['name']] = attr.as_u16()
            elif attr_type == 'u32':
                rsp[attr_spec['name']] = attr.as_u32()
            elif attr_type == 'u64':
                rsp[attr_spec['name']] = attr.as_u64()
            elif attr_type == 'string':
                rsp[attr_spec['name']] = attr.as_strz()
            elif attr_type == 'binary':
                rsp[attr_spec['name']] = attr.as_bin()
            elif attr_type == 'flag':
                rsp[attr_spec['name']] = True
            else:
                raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')

            if 'enum' in attr_spec:
                self._decode_enum(rsp, attr_spec)
        return rsp

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to ``async_msg_queue``.

        The queued dict has 'name' (the message name) and 'msg' (the decoded
        attributes), plus 'nlmsg'/'genlmsg' when ``include_raw`` is set.
        """
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain the socket without blocking, queueing any notifications.

        Errors, DONE messages and non-notification messages are printed and
        skipped. Returns once no more data is immediately available.
        """
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def _op(self, method, vals, dump=False):
        """Send request ``method`` with attributes ``vals`` and collect the reply.

        Args:
            method: operation name from the spec.
            vals: dict of request attribute name -> value.
            dump: send with NLM_F_DUMP and keep collecting until NLMSG_DONE.

        Returns:
            None when no reply arrived or on a netlink error (which is
            printed). Otherwise, the single decoded dict for a non-dump request
            with exactly one reply, or a list of decoded dicts.
            Notifications that arrive meanwhile are queued via handle_ntf().
        """
        op = self.ops[method]

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    return
                # NLMSG_DONE, or the final ACK (an NLMSG_ERROR with error 0).
                if nl_msg.done:
                    done = True
                    break

                gm = GenlMsg(nl_msg)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                        continue
                    else:
                        print('Unexpected message: ' + repr(gm))
                        continue

                rsp.append(self._decode(gm.raw_attrs, op.attr_set.name))

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp