[PATCH v3 6/7] babel: Refactor TLV parsing code for easier reuse

Toke Høiland-Jørgensen toke at toke.dk
Tue Nov 24 16:21:53 CET 2020


From: Toke Høiland-Jørgensen <toke at toke.dk>

In preparation for adding authentication checks, refactor the TLV walking
code so that it can be reused for a separate authentication pass over the
packet.

Signed-off-by: Toke Høiland-Jørgensen <toke at toke.dk>
---
 proto/babel/packets.c |  166 +++++++++++++++++++++++++++++++------------------
 1 file changed, 104 insertions(+), 62 deletions(-)

diff --git a/proto/babel/packets.c b/proto/babel/packets.c
index 415ac3f9c..997ef8c22 100644
--- a/proto/babel/packets.c
+++ b/proto/babel/packets.c
@@ -120,8 +120,19 @@ struct babel_subtlv_source_prefix {
 #define BABEL_UF_DEF_PREFIX	0x80
 #define BABEL_UF_ROUTER_ID	0x40
 
+struct babel_parse_state;
+struct babel_write_state;
+
+struct babel_tlv_data {
+  u8 min_length;
+  int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state);
+  uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len);
+  void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa);
+};
 
 struct babel_parse_state {
+  const struct babel_tlv_data* (*get_tlv_data)(u8 type);
+  const struct babel_tlv_data* (*get_subtlv_data)(u8 type);
   struct babel_proto *proto;
   struct babel_iface *ifa;
   ip_addr saddr;
@@ -167,6 +178,33 @@ struct babel_write_state {
 
 #define NET_SIZE(n) BYTES(net_pxlen(n))
 
+
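+/* Helper macros to loop over a series of TLVs; close with WALK_TLVS_END.
+ * @start pointer to first TLV; @end byte * pointer to TLV stream end
+ * @tlv   struct babel_tlv pointer used as iterator
+ * @saddr, @ifname are only logged; jumps to a frame_err label on bad framing
+ */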
+#define WALK_TLVS(start, end, tlv, saddr, ifname)                       \
+  for (tlv = (void *)start;						\
+       (byte *)tlv < end;						\
+       tlv = NEXT_TLV(tlv))						\
+  {									\
+    byte *loop_pos;							\
+    /* Ugly special case */						\
+    if (tlv->type == BABEL_TLV_PAD1)					\
+      continue;                                                         \
+									\
+    /* The end of the common TLV header */				\
+    loop_pos = (byte *)tlv + sizeof(struct babel_tlv);			\
+    if ((loop_pos > end) || (loop_pos + tlv->length > end))             \
+    {                                                                   \
+      LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error",  \
+	      saddr, ifname, tlv->type, (byte *)tlv - (byte *)start);   \
+      goto frame_err;							\
+    }
+
+#define WALK_TLVS_END }
+
 static inline uint
 bytes_equal(u8 *b1, u8 *b2, uint maxlen)
 {
@@ -255,13 +293,6 @@ static uint babel_write_route_request(struct babel_tlv *hdr, union babel_msg *ms
 static uint babel_write_seqno_request(struct babel_tlv *hdr, union babel_msg *msg, struct babel_write_state *state, uint max_len);
 static int babel_write_source_prefix(struct babel_tlv *hdr, net_addr *net, uint max_len);
 
-struct babel_tlv_data {
-  u8 min_length;
-  int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state);
-  uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len);
-  void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa);
-};
-
 static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = {
   [BABEL_TLV_ACK_REQ] = {
     sizeof(struct babel_tlv_ack_req),
@@ -319,6 +350,30 @@ static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = {
   },
 };
 
+static const struct babel_tlv_data *get_packet_tlv_data(u8 type)
+{
+  return type < sizeof(tlv_data) / sizeof(*tlv_data) ? &tlv_data[type] : NULL;
+}
+
+static const struct babel_tlv_data source_prefix_tlv_data = {
+  sizeof(struct babel_subtlv_source_prefix),
+  babel_read_source_prefix,
+  NULL,
+  NULL
+};
+
+static const struct babel_tlv_data *get_packet_subtlv_data(u8 type)
+{
+  switch(type)
+  {
+  case BABEL_SUBTLV_SOURCE_PREFIX:
+    return &source_prefix_tlv_data;
+
+  default:
+    return NULL;
+  }
+}
+
 static int
 babel_read_ack_req(struct babel_tlv *hdr, union babel_msg *m,
 		   struct babel_parse_state *state)
@@ -1083,69 +1138,67 @@ babel_write_source_prefix(struct babel_tlv *hdr, net_addr *n, uint max_len)
   return len;
 }
 
-
 static inline int
 babel_read_subtlvs(struct babel_tlv *hdr,
 		   union babel_msg *msg,
 		   struct babel_parse_state *state)
 {
+  const struct babel_tlv_data *tlv_data;
+  struct babel_proto *p = state->proto;
   struct babel_tlv *tlv;
-  byte *pos, *end = (byte *) hdr + TLV_LENGTH(hdr);
+  byte *end = (byte *) hdr + TLV_LENGTH(hdr);
   int res;
 
-  for (tlv = (void *) hdr + state->current_tlv_endpos;
-       (byte *) tlv < end;
-       tlv = NEXT_TLV(tlv))
+  WALK_TLVS(hdr + state->current_tlv_endpos, end, tlv,
+            state->saddr, state->ifa->ifname)
   {
-    /* Ugly special case */
-    if (tlv->type == BABEL_TLV_PAD1)
+    if (tlv->type == BABEL_SUBTLV_PADN)
       continue;
 
-    /* The end of the common TLV header */
-    pos = (byte *)tlv + sizeof(struct babel_tlv);
-    if ((pos > end) || (pos + tlv->length > end))
-      return PARSE_ERROR;
-
-    /*
-     * The subtlv type space is non-contiguous (due to the mandatory bit), so
-     * use a switch for dispatch instead of the mapping array we use for TLVs
-     */
-    switch (tlv->type)
+    if (!state->get_subtlv_data ||
+        !(tlv_data = state->get_subtlv_data(tlv->type)) ||
+        !tlv_data->read_tlv)
     {
-    case BABEL_SUBTLV_SOURCE_PREFIX:
-      res = babel_read_source_prefix(tlv, msg, state);
-      if (res != PARSE_SUCCESS)
-	return res;
-      break;
-
-    case BABEL_SUBTLV_PADN:
-    default:
       /* Unknown mandatory subtlv; PARSE_IGNORE ignores the whole TLV */
       if (tlv->type >= 128)
-	return PARSE_IGNORE;
-      break;
+        return PARSE_IGNORE;
+      continue;
     }
+
+    res = tlv_data->read_tlv(tlv, msg, state);
+    if (res != PARSE_SUCCESS)
+      return res;
   }
+  WALK_TLVS_END;
 
   return PARSE_SUCCESS;
+
+ frame_err:
+  return PARSE_ERROR;
 }
 
-static inline int
+static int
 babel_read_tlv(struct babel_tlv *hdr,
                union babel_msg *msg,
                struct babel_parse_state *state)
 {
+  const struct babel_tlv_data *tlv_data;
+
   if ((hdr->type <= BABEL_TLV_PADN) ||
-      (hdr->type >= BABEL_TLV_MAX) ||
-      !tlv_data[hdr->type].read_tlv)
+      (hdr->type >= BABEL_TLV_MAX))
     return PARSE_IGNORE;
 
-  if (TLV_LENGTH(hdr) < tlv_data[hdr->type].min_length)
+  tlv_data = state->get_tlv_data(hdr->type);
+
+  if (!tlv_data || !tlv_data->read_tlv)
+    return PARSE_IGNORE;
+
+  if (TLV_LENGTH(hdr) < tlv_data->min_length)
     return PARSE_ERROR;
 
-  state->current_tlv_endpos = tlv_data[hdr->type].min_length;
+  state->current_tlv_endpos = tlv_data->min_length;
 
-  int res = tlv_data[hdr->type].read_tlv(hdr, msg, state);
+  int res = tlv_data->read_tlv(hdr, msg, state);
   if (res != PARSE_SUCCESS)
     return res;
 
@@ -1337,15 +1390,16 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
   int res;
 
   int plen = sizeof(struct babel_pkt_header) + get_u16(&pkt->length);
-  byte *pos;
   byte *end = (byte *)pkt + plen;
 
   struct babel_parse_state state = {
-    .proto	  = p,
-    .ifa	  = ifa,
-    .saddr	  = saddr,
-    .next_hop_ip6 = saddr,
-    .sadr_enabled = babel_sadr_enabled(p),
+    .get_tlv_data    = &get_packet_tlv_data,
+    .get_subtlv_data = &get_packet_subtlv_data,
+    .proto           = p,
+    .ifa             = ifa,
+    .saddr           = saddr,
+    .next_hop_ip6    = saddr,
+    .sadr_enabled    = babel_sadr_enabled(p),
   };
 
   if ((pkt->magic != BABEL_MAGIC) || (pkt->version != BABEL_VERSION))
@@ -1369,23 +1423,8 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
 
   /* First pass through the packet TLV by TLV, parsing each into internal data
      structures. */
-  for (tlv = FIRST_TLV(pkt);
-       (byte *)tlv < end;
-       tlv = NEXT_TLV(tlv))
+  WALK_TLVS(FIRST_TLV(pkt), end, tlv, saddr, ifa->iface->name)
   {
-    /* Ugly special case */
-    if (tlv->type == BABEL_TLV_PAD1)
-      continue;
-
-    /* The end of the common TLV header */
-    pos = (byte *)tlv + sizeof(struct babel_tlv);
-    if ((pos > end) || (pos + tlv->length > end))
-    {
-      LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error",
-	      saddr, ifa->iface->name, tlv->type, (byte *)tlv - (byte *)pkt);
-      break;
-    }
-
     msg = sl_allocz(p->msg_slab);
     res = babel_read_tlv(tlv, &msg->msg, &state);
     if (res == PARSE_SUCCESS)
@@ -1405,6 +1444,9 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
       break;
     }
   }
+  WALK_TLVS_END;
+
+frame_err:
 
   /* Parsing done, handle all parsed TLVs */
   WALK_LIST_FIRST(msg, msgs)



More information about the Bird-users mailing list