# Source code for ebpfcat.ebpf

# ebpfcat, A Python-based EBPF generator and EtherCAT master
# Copyright (C) 2021 Martin Teichmann <martin.teichmann@gmail.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

"""\
The :mod:`!ebpfcat.ebpf` module contains the core ebpf code generation
======================================================================
"""

__all__ = ["EBPF", "LocalVar", "Member", "Structure", "prandom", "ktime"]

import os
from abc import ABC, abstractmethod
from collections import namedtuple
from contextlib import contextmanager, ExitStack
from operator import index
from struct import pack, pack_into, unpack, unpack_from, calcsize
from enum import Enum

from . import bpf
from .util import sub

# A single instruction, held as its five fields: the opcode, the
# destination and source register numbers, an offset (the jump distance
# for the jump instructions issued in this module) and an immediate value.
Instruction = namedtuple("Instruction",
                         ["opcode", "dst", "src", "off", "imm"])


class FuncId(Enum):
    """Enumeration mapping function names to their integer ids

    NOTE(review): the names and the consecutive numbering appear to follow
    the Linux kernel's list of BPF helper functions (``enum bpf_func_id``)
    -- confirm against the kernel headers before adding new entries.
    """
    unspec = 0
    map_lookup_elem = 1
    map_update_elem = 2
    map_delete_elem = 3
    probe_read = 4
    ktime_get_ns = 5
    trace_printk = 6
    get_prandom_u32 = 7
    get_smp_processor_id = 8
    skb_store_bytes = 9
    l3_csum_replace = 10
    l4_csum_replace = 11
    tail_call = 12
    clone_redirect = 13
    get_current_pid_tgid = 14
    get_current_uid_gid = 15
    get_current_comm = 16
    get_cgroup_classid = 17
    skb_vlan_push = 18
    skb_vlan_pop = 19
    skb_get_tunnel_key = 20
    skb_set_tunnel_key = 21
    perf_event_read = 22
    redirect = 23
    get_route_realm = 24
    perf_event_output = 25
    skb_load_bytes = 26
    get_stackid = 27
    csum_diff = 28
    skb_get_tunnel_opt = 29
    skb_set_tunnel_opt = 30
    skb_change_proto = 31
    skb_change_type = 32
    skb_under_cgroup = 33
    get_hash_recalc = 34
    get_current_task = 35
    probe_write_user = 36
    current_task_under_cgroup = 37
    skb_change_tail = 38
    skb_pull_data = 39
    csum_update = 40
    set_hash_invalid = 41
    get_numa_node_id = 42
    skb_change_head = 43
    xdp_adjust_head = 44
    probe_read_str = 45
    get_socket_cookie = 46
    get_socket_uid = 47
    set_hash = 48
    setsockopt = 49
    skb_adjust_room = 50
    redirect_map = 51
    sk_redirect_map = 52
    sock_map_update = 53
    xdp_adjust_meta = 54
    perf_event_read_value = 55
    perf_prog_read_value = 56
    getsockopt = 57
    override_return = 58
    sock_ops_cb_flags_set = 59
    msg_redirect_map = 60
    msg_apply_bytes = 61
    msg_cork_bytes = 62
    msg_pull_data = 63
    bind = 64
    xdp_adjust_tail = 65
    skb_get_xfrm_state = 66
    get_stack = 67
    skb_load_bytes_relative = 68
    fib_lookup = 69
    sock_hash_update = 70
    msg_redirect_hash = 71
    sk_redirect_hash = 72
    lwt_push_encap = 73
    lwt_seg6_store_bytes = 74
    lwt_seg6_adjust_srh = 75
    lwt_seg6_action = 76
    rc_repeat = 77
    rc_keydown = 78
    skb_cgroup_id = 79
    get_current_cgroup_id = 80
    get_local_storage = 81
    sk_select_reuseport = 82
    skb_ancestor_cgroup_id = 83
    sk_lookup_tcp = 84
    sk_lookup_udp = 85
    sk_release = 86
    map_push_elem = 87
    map_pop_elem = 88
    map_peek_elem = 89
    msg_push_data = 90
    msg_pop_data = 91
    rc_pointer_rel = 92
    spin_lock = 93
    spin_unlock = 94
    sk_fullsock = 95
    tcp_sock = 96
    skb_ecn_set_ce = 97
    get_listener_sock = 98
    skc_lookup_tcp = 99
    tcp_check_syncookie = 100
    sysctl_get_name = 101
    sysctl_get_current_value = 102
    sysctl_get_new_value = 103
    sysctl_set_new_value = 104
    strtol = 105
    strtoul = 106
    sk_storage_get = 107
    sk_storage_delete = 108
    send_signal = 109
    tcp_gen_syncookie = 110
    skb_output = 111
    probe_read_user = 112
    probe_read_kernel = 113
    probe_read_user_str = 114
    probe_read_kernel_str = 115
    tcp_send_ack = 116
    send_signal_thread = 117
    jiffies64 = 118
    read_branch_records = 119
    get_ns_current_pid_tgid = 120
    xdp_output = 121
    get_netns_cookie = 122
    get_current_ancestor_cgroup_id = 123
    sk_assign = 124
    ktime_get_boot_ns = 125
    seq_printf = 126
    seq_write = 127
    sk_cgroup_id = 128
    sk_ancestor_cgroup_id = 129
    ringbuf_output = 130
    ringbuf_reserve = 131
    ringbuf_submit = 132
    ringbuf_discard = 133
    ringbuf_query = 134
    csum_level = 135
    skc_to_tcp6_sock = 136
    skc_to_tcp_sock = 137
    skc_to_tcp_timewait_sock = 138
    skc_to_tcp_request_sock = 139
    skc_to_udp6_sock = 140
    get_task_stack = 141
    load_hdr_opt = 142
    store_hdr_opt = 143
    reserve_hdr_opt = 144
    inode_storage_get = 145
    inode_storage_delete = 146
    d_path = 147
    copy_from_user = 148
    snprintf_btf = 149
    seq_printf_btf = 150
    skb_cgroup_classid = 151
    redirect_neigh = 152
    per_cpu_ptr = 153
    this_cpu_ptr = 154
    redirect_peer = 155

class Opcode(Enum):
    """The numeric building blocks of instruction opcodes

    Members are combined by addition: ``Opcode.A + Opcode.B`` gives an
    :class:`OpcodeFlags` whose value is the sum of the member values.
    Multiplying a member with a truth value includes (true) or leaves
    out (false) that member, as in ``Opcode.MOV + Opcode.LONG * long``.
    """
    # arithmetic and logic operations
    ADD = 4
    SUB = 0x14
    MUL = 0x24
    DIV = 0x34
    OR = 0x44
    AND = 0x54
    LSH = 0x64
    RSH = 0x74
    NEG = 0x84
    MOD = 0x94
    XOR = 0xa4
    MOV = 0xb4
    ARSH = 0xc4

    # jumps
    JMP = 5
    JEQ = 0x15
    JGT = 0x25
    JGE = 0x35
    JSET = 0x45
    JNE = 0x55
    JSGT = 0x65
    JSGE = 0x75
    JLT = 0xa5
    JLE = 0xb5
    JSLT = 0xc5
    JSLE = 0xd5
    SHORT = 1

    CALL = 0x85
    EXIT = 0x95

    # modifiers added onto the operations above
    REG = 8
    LONG = 3

    # operand sizes; H has the same value as REG, so by the rules of
    # Enum it is an alias of REG and not a member of its own
    W = 0
    H = 8
    B = 0x10
    DW = 0x18

    # memory access and byte order conversion
    LD = 0x61
    ST = 0x62
    STX = 0x63
    XADD = 0xc3
    LE = 0xd4
    BE = 0xdc

    def __mul__(self, value):
        """return flags with only this opcode if `value` is true, else none"""
        return OpcodeFlags({self} if value else set())

    def __add__(self, value):
        """combine this opcode with another opcode or with flags"""
        flags = OpcodeFlags({self})
        return flags + value

    def __repr__(self):
        return f"O.{self.name}"

class OpcodeFlags:
    """A combination of several :class:`Opcode` members

    The numerical `value` of the combination is the sum of the values of
    all contained opcodes. Combinations are created by adding opcodes.

    :param opcodes: the set of opcodes to combine
    """
    def __init__(self, opcodes):
        self.opcodes = opcodes

    @property
    def value(self):
        """the sum of the values of all contained opcodes"""
        return sum(op.value for op in self.opcodes)

    def __add__(self, value):
        """return new flags which additionally contain `value`

        `value` may be a single :class:`Opcode` or other flags. Anything
        else is rejected the standard way, resulting in a `TypeError`.
        """
        if isinstance(value, OpcodeFlags):
            return OpcodeFlags(self.opcodes | value.opcodes)
        if isinstance(value, Opcode):
            return OpcodeFlags(self.opcodes | {value})
        return NotImplemented

    def __repr__(self):
        # sets have no defined order, sort to get a reproducible output
        ordered = sorted(self.opcodes, key=lambda op: op.value)
        return "+".join(repr(op) for op in ordered)

    def __eq__(self, value):
        """flags are equal to anything that has the same `value`

        this includes single :class:`Opcode` members. Objects without a
        `value` attribute never compare equal.
        """
        try:
            other = value.value
        except AttributeError:
            return NotImplemented
        return self.value == other

    # flags hold a mutable set and define __eq__, so they are unhashable
    __hash__ = None


class AssembleError(Exception):
    """Raised when a program cannot be assembled as written

    for example if an expression is used as if it already had a value
    while the program is still being generated.
    """


def comparison(uposop, unegop, sposop, snegop):
    """create the method implementing one comparison operator

    The returned function is meant to be used as a rich comparison
    method of :class:`Expression`. It returns a :class:`SimpleComparison`
    of both operands, scaling one side by `FIXED_BASE` if only the other
    one is a fixed-point value.

    :param uposop: the jump opcode for unsigned operands
    :param unegop: the opcode of the inverse unsigned comparison
    :param sposop: the jump opcode if any operand is signed
    :param snegop: the opcode of the inverse signed comparison
    """
    unsigned_ops = (uposop, unegop)
    signed_ops = (sposop, snegop)

    def ret(self, value):
        value = ensure_expression(self.ebpf, value)
        myself = self
        if self.fixed != value.fixed:
            # bring both sides to the same scale
            if self.fixed:
                value *= self.FIXED_BASE
            else:
                myself *= self.FIXED_BASE

        ops = signed_ops if self.signed or value.signed else unsigned_ops
        return SimpleComparison(self.ebpf, myself, value, ops)

    return ret


class Elser:
    """The context manager for the else branch of a comparison

    Entering it starts the else branch by calling the `Else` method of
    the comparison, leaving it is forwarded to the `__exit__` method of
    that comparison.

    :param comp: the comparison this is the else branch of
    """
    def __init__(self, comp):
        self.comp = comp

    def __enter__(self):
        comparison = self.comp
        return comparison.Else()

    def __exit__(self, exc_type, exc, tb):
        # the result of the comparison's __exit__ is deliberately dropped
        comparison = self.comp
        comparison.__exit__(exc_type, exc, tb)


class Comparison(ABC):
    """Base class for all logical operations

    Comparisons are used in a `with` statement, whose body is only
    executed if the comparison holds. Entering the `with` statement
    returns an :class:`Elser`, which may be used in a second `with`
    statement for the else branch.

    :param ebpf: the program the code is issued into
    """

    def __init__(self, ebpf):
        self.else_origin = None
        self.ebpf = ebpf

    def __enter__(self):
        no_else_branch = self.else_origin is None
        if no_else_branch:
            self.compare(True)
        return Elser(self)

    def __exit__(self, exc_type, exc, tb):
        origin = self.else_origin
        if origin is None:
            # no else branch: the comparison jumps to right here
            self.target()
        else:
            # fill the slot reserved by `Else` with the jump which skips
            # the else branch
            program = self.ebpf
            assert program.opcodes[origin] is None
            distance = len(program.opcodes) - origin - 1
            program.opcodes[origin] = Instruction(
                Opcode.JMP, 0, 0, distance, 0)
            before = program.owners
            program.owners = before & self.owners
            self.owners = before

    @abstractmethod
    def compare(self, negative):
        """issue the code performing the comparison

        The code issued here either jumps away, to a position that is
        only defined later by the `target` method, or it falls through.
        With `negative` being true the jump is taken if the condition
        does not hold, otherwise it is taken if it holds.
        """
        raise NotImplementedError

    @abstractmethod
    def target(self, retarget=False):
        """let the jumps issued so far point to the current position

        This may be done a second time, but only if `retarget` is set
        to true.
        """
        raise NotImplementedError

    def Else(self):
        """start the else branch

        this reserves a slot for the jump at the end of the preceding
        branch, and lets the comparison jump behind that slot.
        """
        opcodes = self.ebpf.opcodes
        self.else_origin = len(opcodes)
        opcodes.append(None)
        self.target(True)
        return self

    def __and__(self, value):
        return AndOrComparison(self.ebpf, self, value, True)

    def __or__(self, value):
        return AndOrComparison(self.ebpf, self, value, False)

    def __invert__(self):
        return InvertComparison(self.ebpf, self)

    def __bool__(self):
        raise AssembleError("Use with statement for comparisons")


class SimpleComparison(Comparison):
    """A simple numerical comparison, results in a jump instruction

    :param ebpf: the program the code is issued into
    :param left: the expression on the left hand side
    :param right: the expression on the right hand side
    :param opcode: a pair of jump opcodes. It is indexed by the `negative`
       parameter of `compare`: the first one jumps if the comparison
       holds, the second one if it does not.
    """

    def __init__(self, ebpf, left, right, opcode):
        super().__init__(ebpf)
        self.left = left
        self.right = right
        self.opcode = opcode

    def compare(self, negative):
        """calculate the operands and reserve the slot for the jump

        The jump instruction itself is only written by `target`, once
        the distance to jump is known. Note that this method replaces
        the opcode pair in `self.opcode` by the opcode chosen.
        """
        with self.left.calculate(None, None) as (self.dst, l_long):
            with ExitStack() as exitStack:
                # a small constant needs no register, `target` puts it
                # into the jump instruction as its immediate value
                if not self.right.small_constant:
                    self.src, r_long = exitStack.enter_context(
                        self.right.calculate(
                            None, self.left.signed and l_long or None))
                else:
                    r_long = False
                self.opcode = self.opcode[negative]
                if self.left.signed or self.right.signed:
                    if not l_long and not r_long:
                        # both operands are 32 bit
                        self.opcode += Opcode.SHORT
                    elif not l_long and r_long:
                        # the left operand is 32 bit, the right one is
                        # 64 bit: shift left up and, via the signed
                        # register view, back down -- presumably in
                        # order to sign-extend it to 64 bit
                        self.ebpf.r[self.dst] <<= 32
                        self.ebpf.sr[self.dst] >>= 32
                # the placeholder for the jump, to be filled by `target`
                self.origin = len(self.ebpf.opcodes)
                self.ebpf.opcodes.append(None)
        # remember the owners at the time of the jump
        self.owners = self.ebpf.owners.copy()

    def target(self, retarget=False):
        """write the jump such that it jumps to the current position"""
        assert retarget or self.ebpf.opcodes[self.origin] is None
        if self.opcode == Opcode.JMP:
            # an unconditional jump, it has no operands
            inst = Instruction(Opcode.JMP, 0, 0,
                               len(self.ebpf.opcodes) - self.origin - 1, 0)
        elif self.right.small_constant:
            # compare the register to an immediate value
            inst = Instruction(
                self.opcode, self.dst, 0,
                len(self.ebpf.opcodes) - self.origin - 1,
                int(self.right.value))
        else:
            # compare two registers
            inst = Instruction(
                self.opcode + Opcode.REG, self.dst, self.src,
                len(self.ebpf.opcodes) - self.origin - 1, 0)
        self.ebpf.opcodes[self.origin] = inst
        if not retarget:
            # here the jump and the code fallen through meet again: only
            # owners present on both paths remain. The owners before the
            # merge are kept in self.owners.
            self.ebpf.owners, self.owners = \
                    self.ebpf.owners & self.owners, self.ebpf.owners


class AndOrComparison(Comparison):
    """The combination of two comparisons with `&` or `|`

    :param ebpf: the program the code is issued into
    :param left: the comparison on the left hand side
    :param right: the comparison on the right hand side
    :param is_and: true for a logical and, false for a logical or
    """
    def __init__(self, ebpf, left, right, is_and):
        super().__init__(ebpf)
        self.left = left
        self.right = right
        self.is_and = is_and

    def compare(self, negative):
        self.negative = negative
        # the left comparison short-cuts: for an and it jumps away if it
        # does not hold, for an or it jumps away if it holds
        self.left.compare(self.is_and)
        self.right.compare(negative)
        self.origin = len(self.ebpf.opcodes)
        if self.is_and != negative:
            # the short-cut of the left comparison leads to the opposite
            # result of what the jump of the right one stands for, so it
            # must not jump along, but ends right here, behind the jump
            # of the right comparison
            self.left.target()
        self.owners = self.ebpf.owners.copy()

    def target(self, retarget=False):
        # the left comparison only jumps along if its target has not
        # already been set in `compare`
        if self.is_and == self.negative:
            self.left.target(retarget)
        self.right.target(retarget)


class InvertComparison(Comparison):
    """The logical inversion of another comparison

    :param ebpf: the program the code is issued into
    :param value: the comparison to be inverted
    """
    def __init__(self, ebpf, value):
        super().__init__(ebpf)
        self.value = value

    def compare(self, negative):
        # inverting simply means comparing the other way around
        inner = self.value
        inner.compare(not negative)
        self.owners = inner.owners

    def target(self, retarget=False):
        inner = self.value
        inner.target(retarget)


def ensure_expression(ebpf, value):
    """return `value` as an expression

    Expressions are returned unchanged, everything else gets wrapped
    into a :class:`Constant` belonging to the program `ebpf`.
    """
    return value if isinstance(value, Expression) else Constant(ebpf, value)


def fmtsize(fmt):
    """return the number of bytes a value of format `fmt` occupies

    :param fmt: a :mod:`struct` format string, the special format ``"x"``
       which occupies 8 bytes, or anything which is not a string, which
       is taken as a bit access within one single byte.
    """
    if fmt == "x":
        return 8
    return calcsize(fmt) if isinstance(fmt, str) else 1


class Expression:
    """the base class for all numerical expressions

    Applying Python operators to an expression does not calculate
    anything. Instead a new expression is returned (mostly a
    :class:`Binary`), or a :class:`Comparison` for the comparison
    operators. Code is only issued once `calculate` is called.

    Deriving classes need to provide the attributes `ebpf`, `signed`
    and `fixed`, as they are used by the methods of this class.
    """

    # fixed-point values are integers scaled by this factor
    FIXED_BASE = 100000
    # overridden by :class:`Constant`
    small_constant = False

    def _binary(self, value, opcode):
        """a binary operation which never results in a fixed-point value"""
        value = ensure_expression(self.ebpf, value)
        return Binary(self.ebpf, self, value, opcode,
                      self.signed or value.signed, False)

    __ror__ = __or__ = lambda self, value: self._binary(value, Opcode.OR)
    __lshift__ = lambda self, value: self._binary(value, Opcode.LSH)
    __rlshift__ = lambda self, value: Constant(self.ebpf, value) << self
    __rxor__ = __xor__ = lambda self, value: self._binary(value, Opcode.XOR)

    __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE)
    __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT)
    __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE)
    __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT)
    __ne__ = comparison(Opcode.JNE, Opcode.JEQ, Opcode.JNE, Opcode.JEQ)

    def _sum(self, value, opcode):
        """a binary operation for which both operands need the same scale

        if only one of the operands is a fixed-point value, the other
        one is multiplied by `FIXED_BASE` first.
        """
        value = ensure_expression(self.ebpf, value)
        myself = self
        if self.fixed != value.fixed:
            if self.fixed:
                value *= self.FIXED_BASE
            else:
                myself *= self.FIXED_BASE

        return Binary(self.ebpf, myself, value, opcode,
                      self.signed or value.signed, self.fixed or value.fixed)

    __radd__ = __add__ = lambda self, value: self._sum(value, Opcode.ADD)
    __sub__ = lambda self, value: self._sum(value, Opcode.SUB)
    __rsub__ = lambda self, value: Constant(self.ebpf, value) - self
    __mod__ = lambda self, value: self._sum(value, Opcode.MOD)
    __rmod__ = lambda self, value: Constant(self.ebpf, value) % self

    def __mul__(self, value):
        value = ensure_expression(self.ebpf, value)
        ret = Binary(self.ebpf, self, value, Opcode.MUL,
                     self.signed or value.signed, self.fixed or value.fixed)
        if self.fixed and value.fixed:
            # the product of two fixed-point values is scaled twice
            ret /= self.FIXED_BASE
        return ret
    __rmul__ = __mul__

    def __truediv__(self, value):
        """the true division always results in a fixed-point value"""
        value = ensure_expression(self.ebpf, value)
        myself = self
        # scale the dividend such that the quotient ends up being
        # scaled by FIXED_BASE exactly once
        if not self.fixed and value.fixed:
            myself *= self.FIXED_BASE ** 2
        elif self.fixed == value.fixed:
            myself *= self.FIXED_BASE

        return Binary(self.ebpf, myself, value, Opcode.DIV,
                      self.signed or value.signed, True)

    def _reverse(self, op, value):
        return op(Constant(self.ebpf, value), self)

    __rtruediv__ = lambda self, value: Constant(self.ebpf, value) / self

    def __floordiv__(self, value):
        """the floor division never results in a fixed-point value"""
        value = ensure_expression(self.ebpf, value)
        myself = self
        # bring both operands to the same scale
        if not self.fixed and value.fixed:
            myself *= self.FIXED_BASE
        elif self.fixed and not value.fixed:
            value *= self.FIXED_BASE

        return Binary(self.ebpf, myself, value, Opcode.DIV,
                      self.signed or value.signed, False)

    def __rfloordiv__(self, value):
        if self.fixed:
            value = Constant(self.ebpf, value)
            if not value.fixed:
                value *= self.FIXED_BASE
        else:
            value = Constant(self.ebpf, int(value))

        return Binary(self.ebpf, value, self, Opcode.DIV,
                      self.signed or value.signed, False)

    def __rshift__(self, value):
        # signed values are shifted arithmetically
        opcode = Opcode.ARSH if self.signed else Opcode.RSH
        return Binary(self.ebpf, self, ensure_expression(self.ebpf, value),
                      opcode, self.signed, False)

    __rrshift__ = lambda self, value: Constant(self.ebpf, value) >> self

    def __and__(self, value):
        return AndExpression(self.ebpf, self,
                             ensure_expression(self.ebpf, value))

    def __eq__(self, value):
        # equality is the inversion of the comparison for inequality
        return ~(self != value)

    __rand__ = __and__

    def __neg__(self):
        return Negate(self)

    def __abs__(self):
        return Absolute(self)

    def switch_endian(self, fmt):
        """return this expression converted to the byte order of `fmt`

        only formats which are strings of more than one character lead
        to a conversion, otherwise the expression itself is returned.
        """
        if isinstance(fmt, str) and len(fmt) > 1:
            return SwitchEndian(self, fmt)
        return self

    def __bool__(self):
        raise AssembleError("Expression only has a value at execution time")

    def __enter__(self):
        # an expression used in a `with` statement is tested for
        # not being zero
        ret = self != 0
        self.as_comparison = ret
        return ret.__enter__()

    def __exit__(self, exc_type, exc, tb):
        return self.as_comparison.__exit__(exc_type, exc, tb)

    def load(self, dst, src, offset, fmt, long):
        """issue the code loading a value from memory into a register

        :param dst: the number of the register to load the value into
        :param src: the number of the register containing the address
        :param offset: the offset added to the address in `src`
        :param fmt: the format of the value in memory
        :param long: true if the result is supposed to be 64 bit
        """
        self.ebpf.append(Opcode.LD + fmt_to_opcode(fmt), dst, src, offset, 0)
        if isinstance(fmt, str) and (fmt in ("h", "b")
                                     or long and fmt == 'i'):
            # a signed value narrower than the register: shift it up and
            # back down via the signed register view, presumably in
            # order to extend its sign
            shift = (64 if long else 32) - calcsize(fmt) * 8
            regs = self.ebpf.sr if long else self.ebpf.sw
            regs[dst] = (regs[dst] << shift) >> shift

    @contextmanager
    def calculate(self, dst, long, force=False):
        """issue the code that calculates the value of this expression

        this method returns two values:

        - the number of the register with the result
        - a boolean indicating whether this is a 64 bit value

        this method is a contextmanager to be used in a `with`
        statement. At the end of the `with` block the result is
        freed again, i.e. the register will not be reserved for the
        result anymore.

        the default implementation calls `get_address` for values
        which actually are in memory and moves that into a register.

        :param dst: the number of the register to put the result in,
           or `None` if that does not matter.
        :param long: True if the result is supposed to be 64 bit. None
           if it does not matter.
        :param force: if true, `dst` must be respected, otherwise this
           is optional.
        """
        with self.ebpf.get_free_register(dst) as dst:
            with self.get_address(dst, long) as (src, fmt):
                self.load(dst, src, 0, fmt, long)
                yield dst, long

    @contextmanager
    def get_address(self, dst, long, force=False):
        """get the address of the value of this expression

        this method returns the address of the result of this expression,
        and its format letter. The default implementation uses
        `calculate` to evaluate the expression and pushes the result onto
        the stack.

        :param dst: the number of the register to put the address in
        :param long: true if the value is to be stored as 64 bit. If it
           is false or `None`, the value is stored as 32 bit.
        """
        # `long` may be None, so the size cannot be calculated from it
        with self.ebpf.get_stack(8 if long else 4) as stack:
            with self.calculate(dst, long) as (src, _):
                self.ebpf.append(Opcode.STX + Opcode.DW * long,
                                 10, src, stack, 0)
                # the address is register 10 plus the stack offset
                self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG,
                                 dst, 10, 0, 0)
                self.ebpf.append(Opcode.ADD + Opcode.LONG, dst, 0, 0, stack)
            yield dst, "Q" if long else "I"

    def contains(self, no):
        """return whether this expression contains the register `no`"""
        return False


class Binary(Expression):
    """represent all binary expressions

    :param ebpf: the program the code is issued into
    :param left: the expression on the left hand side
    :param right: the expression on the right hand side
    :param operator: the opcode of the operation
    :param signed: whether the result is a signed value
    :param fixed: whether the result is a fixed-point value
    """
    def __init__(self, ebpf, left, right, operator, signed, fixed):
        self.ebpf = ebpf
        self.left = left
        self.right = right
        self.operator = operator
        self.signed = signed
        self.fixed = fixed

    @contextmanager
    def calculate(self, dst, long, force=False):
        """issue the code for the operation

        the left operand is calculated into the result register, then
        the operation is applied to it in place, with the right operand
        being either an immediate value or a second register.
        """
        orig_dst = dst
        # calculating the left operand into dst would overwrite a
        # register which the right operand still needs, so use any
        # other register in that case
        if isinstance(self.right, Expression) and self.right.contains(dst):
            dst = None
        with self.ebpf.get_free_register(dst) as dst:
            with self.left.calculate(dst, long, True) as (dst, l_long):
                # without a requested size, take the one of the left side
                if long is None:
                    long = l_long
            if self.right.small_constant:
                self.ebpf.append(self.operator + Opcode.LONG * long,
                                 dst, 0, 0, int(self.right.value))
            else:
                with self.right.calculate(None, long) as (src, r_long):
                    self.ebpf.append(
                        self.operator + Opcode.REG
                        + Opcode.LONG * ((r_long or l_long)
                                         if long is None else long),
                        dst, src, 0, 0)
            if orig_dst is None or orig_dst == dst:
                yield dst, long
                return
        # the result was calculated in a different register than the
        # one requested, so move it there. This happens after the
        # temporary register has been given back.
        self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * long, orig_dst, dst, 0, 0)
        yield orig_dst, long

    def contains(self, no):
        """return whether any of the operands contains the register `no`"""
        return self.left.contains(no) or (isinstance(self.right, Expression)
                                          and self.right.contains(no))


class Unary(Expression):
    """The base class for operations with one operand

    Deriving classes implement the method `calculate_unary(dst, long)`,
    which issues the code applying the operation in place to the
    register `dst`.

    :param arg: the expression the operation is applied to
    """
    def __init__(self, arg):
        self.ebpf = arg.ebpf
        self.arg = arg
        # by default the result is of the same kind as the operand
        self.fixed = arg.fixed
        self.signed = arg.signed

    @contextmanager
    def calculate(self, dst, long, force=False):
        with self.arg.calculate(dst, long, force) as result:
            register, is_long = result
            self.calculate_unary(register, is_long)
            yield register, is_long

    def contains(self, no):
        operand = self.arg
        return operand.contains(no)


class Negate(Unary):
    """The negation of an expression, which is always a signed value"""
    def __init__(self, arg):
        super().__init__(arg)
        self.signed = True

    def calculate_unary(self, dst, long):
        opcode = Opcode.NEG + Opcode.LONG * long
        self.ebpf.append(opcode, dst, 0, 0, 0)


class Absolute(Unary):
    """The absolute value of an expression, which is marked as unsigned"""
    def __init__(self, arg):
        super().__init__(arg)
        self.signed = False

    def calculate_unary(self, dst, long):
        # this issues a conditional block into the program: the register
        # gets negated only if it is negative.
        # NOTE(review): the signed register view `sr` is used regardless
        # of `long` -- confirm this is also correct for 32 bit values.
        with self.ebpf.sr[dst] < 0:
            self.ebpf.sr[dst] = -self.ebpf.sr[dst]


class SwitchEndian(Unary):
    """Convert an expression to a given byte order

    :param arg: the expression to convert
    :param fmt: a two-character :mod:`struct` format: the byte order
       (``<`` for little endian, ``>`` or ``!`` for big endian) followed
       by the character for the data type.
    """
    def __init__(self, arg, fmt):
        super().__init__(arg)
        self.fmt = fmt

    def calculate_unary(self, dst, long):
        """issue the byte order conversion for register `dst`

        :raises AssembleError: if the byte order of the format is
           neither little nor big endian
        """
        endian, size = self.fmt
        if endian == "<":
            opcode = Opcode.LE
        elif endian in (">", "!"):
            opcode = Opcode.BE
        else:
            # formerly this ended in an UnboundLocalError further down
            raise AssembleError(
                f"unsupported byte order in format {self.fmt!r}")
        # the immediate value is the width of the data in bits
        self.ebpf.append(opcode, dst, 0, 0, calcsize(size) * 8)


class Sum(Binary):
    """represent the sum of one register and a constant value

    this is used to optimize memory addressing code.

    Adding an integer to, or subtracting one from, such a sum results
    in a new sum with an adjusted constant. The original sum is left
    unchanged, as it may be in use elsewhere.

    :param ebpf: the program the code is issued into
    :param left: the register
    :param right: the :class:`Constant` added to the register
    """
    def __init__(self, ebpf, left, right):
        super().__init__(ebpf, left, right, Opcode.ADD, right.value < 0, False)

    def _shifted(self, offset):
        """return a new sum with `offset` added to the constant"""
        return Sum(self.ebpf, self.left,
                   Constant(self.ebpf, self.right.value + offset))

    def __add__(self, value):
        try:
            offset = index(value)
        except TypeError:
            # not an integer, fall back to the generic addition
            return super().__add__(value)
        return self._shifted(offset)

    __radd__ = __add__

    def __sub__(self, value):
        try:
            offset = index(value)
        except TypeError:
            # not an integer, fall back to the generic subtraction
            return super().__sub__(value)
        return self._shifted(-offset)


class AndExpression(Binary):
    """The bitwise and of two expressions

    Testing the result for not being zero is handled specially, as
    there is a dedicated jump instruction for exactly that.
    """
    def __init__(self, ebpf, left, right):
        super().__init__(ebpf, left, right, Opcode.AND, False, False)

    def __ne__(self, value):
        # the type check must come first: comparing an expression to 0
        # would not result in a boolean
        if not isinstance(value, int) or value != 0:
            return super().__ne__(value)
        return AndComparison(self.ebpf, self.left, self.right)


class AndComparison(SimpleComparison):
    """Test whether the bitwise and of two expressions is not zero"""
    # there is a special comparison with & instruction
    # it is the only one which has not inversion
    def __init__(self, ebpf, left, right):
        # Binary is not a base class of this class, its __init__ is
        # merely used as a function which sets the attributes
        Binary.__init__(self, ebpf, left, right, Opcode.AND, False, False)
        SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET)
        # the opcode is always JSET, whatever index `compare` uses
        self.opcode = (Opcode.JSET, None, Opcode.JSET, None)
        # the position of the unconditional jump of a negative
        # comparison, set by `Else`
        self.invert = None

    def compare(self, negative):
        # as JSET cannot be inverted, always issue the positive test
        super().compare(False)
        if negative:
            # emulate the inversion: JSET skips a slot reserved for an
            # unconditional jump, and from here on that slot is the one
            # to be filled by `target`
            origin = len(self.ebpf.opcodes)
            self.ebpf.opcodes.append(None)
            self.target()
            self.origin = origin
            self.opcode = Opcode.JMP

    def __exit__(self, exc, etype, tb):
        # NOTE(review): the parameters are named in a different order
        # than usual (exc_type, exc, tb). They are only passed on by
        # position, so this does not matter here.
        super().__exit__(exc, etype, tb)
        if self.invert is not None:
            # there is an else branch following a negative comparison.
            # The instructions behind the slot reserved by `Else` are
            # copied to the position of the unconditional jump, the
            # tail of the program is cut off, and the offset of the
            # instruction in front of the copied ones is rewritten.
            # NOTE(review): this apparently swaps the two branches such
            # that no unconditional jump is needed -- confirm the
            # offsets before touching this code.
            olen = len(self.ebpf.opcodes)
            assert self.ebpf.opcodes[self.invert].opcode == Opcode.JMP
            self.ebpf.opcodes[self.invert:self.invert] = \
                    self.ebpf.opcodes[self.else_origin+1:]
            del self.ebpf.opcodes[olen-1:]
            op, dst, src, off, imm = self.ebpf.opcodes[self.invert - 1]
            self.ebpf.opcodes[self.invert - 1] = \
                    Instruction(op, dst, src,
                                len(self.ebpf.opcodes) - self.else_origin + 1, imm)

    def Else(self):
        op, dst, src, off, imm = self.ebpf.opcodes[self.origin]
        if op is Opcode.JMP:
            # a negative comparison, the work is done in `__exit__`
            self.invert = self.origin
        else:
            # the jump needs to additionally skip the slot reserved below
            self.ebpf.opcodes[self.origin] = \
                Instruction(op, dst, src, off+1, imm)
        self.else_origin = len(self.ebpf.opcodes)
        self.ebpf.opcodes.append(None)
        return self

class Constant(Expression):
    """A value which is already known while the program is generated

    Integers (everything supporting :func:`operator.index`) are kept as
    they are. Any other value is converted to a float and stored as a
    fixed-point value, i.e. multiplied by `FIXED_BASE`.

    :param ebpf: the program the code is issued into
    :param value: the value of the constant
    """
    def __init__(self, ebpf, value):
        self.ebpf = ebpf
        try:
            integral = index(value)
        except TypeError:
            self.value = float(value) * Expression.FIXED_BASE
            self.fixed = True
        else:
            self.value = integral
            self.fixed = False
        self.signed = value < 0

    @property
    def small_constant(self):
        """whether the value is within the range of a signed 32 bit number"""
        lower, upper = -0x80000000, 0x80000000
        return lower <= self.value < upper

    def __imul__(self, value):
        # note that this modifies the constant in place
        self.value *= value
        return self

    @contextmanager
    def calculate(self, dst, long, force=False):
        number = int(self.value)
        is_long = not (-0x80000000 <= number < 0x100000000)
        with self.ebpf.get_free_register(dst) as register:
            if not self.small_constant:
                # large values take two instructions, carrying the lower
                # and the upper 32 bits
                self.ebpf.append(Opcode.DW, register, 0, 0,
                                 number & 0xffffffff)
                self.ebpf.append(Opcode.W, 0, 0, 0, number >> 32)
            else:
                self.ebpf.append(Opcode.MOV + Opcode.LONG, register,
                                 0, 0, number)
            yield register, is_long

    def switch_endian(self, fmt):
        """return this constant converted to the byte order of `fmt`

        as the value is known, the conversion is done right away
        instead of issuing code for it.
        """
        if isinstance(fmt, str) and len(fmt) != 1:
            packed = pack(fmt[-1], self.value)
            return Constant(self.ebpf, *unpack(fmt, packed))
        return self


class Register(Expression):
    """represent one EBPF register

    :param no: the number of the register
    :param ebpf: the program the code is issued into
    :param long: whether the register is used as a 64 bit value
    :param signed: whether the register is used as a signed value
    :param fixed: whether the register is used as a fixed-point value
    """
    offset = 0

    def __init__(self, no, ebpf, long, signed, fixed=False):
        self.ebpf = ebpf
        self.no = no
        self.long = long
        self.signed = signed
        self.fixed = fixed

    def __add__(self, value):
        # only a 64 bit register plus an integer becomes a Sum,
        # everything else is a generic addition
        if not self.long or self.fixed:
            return super().__add__(value)
        try:
            return Sum(self.ebpf, self, Constant(self.ebpf, index(value)))
        except TypeError:
            return super().__add__(value)

    __radd__ = __add__

    def __sub__(self, value):
        # subtracting an integer is adding its negative
        if not self.long or self.fixed:
            return super().__sub__(value)
        try:
            return Sum(self.ebpf, self, Constant(self.ebpf, -index(value)))
        except TypeError:
            return super().__sub__(value)

    @contextmanager
    def calculate(self, dst, long, force=False):
        if self.no not in self.ebpf.owners:
            raise AssembleError(f"register r{self.no} has no value")
        result = self.no
        if force and dst != self.no:
            # the value is demanded in another register, copy it there
            self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * self.long,
                             dst, self.no, 0, 0)
            result = dst
        yield result, self.long

    def contains(self, no):
        return no == self.no


class IAdd:
    """represent an in-place addition

    :param ebpf: the program, used if `value` needs to be wrapped
    :param value: the value to add. Anything which is not an expression
       is turned into a :class:`Constant`.
    """
    def __init__(self, ebpf, value):
        self.value = (value if isinstance(value, Expression)
                      else Constant(ebpf, value))


def fmt_to_opcode(fmt):
    """return the opcode for the data size of the format `fmt`

    For strings only the last character counts, so a leading byte
    order character is ignored. Everything that is not a string is a
    bit access and treated as a single byte.
    """
    if not isinstance(fmt, str):
        return Opcode.B
    sizes = (("IiA", Opcode.W), ("Hh", Opcode.H),
             ("Bb", Opcode.B), ("Qqx", Opcode.DW))
    table = {letter: opcode for letters, opcode in sizes
             for letter in letters}
    return table[fmt[-1]]

class Memory(Expression):
    """A value in memory

    :param ebpf: the program the code is issued into
    :param fmt: the format of the value. This is either a :mod:`struct`
       format string, optionally preceded by a byte order character, or
       a tuple of the position and the number of bits for an access to
       single bits within a byte.
    :param address: the expression for the address of the value
    """
    bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW}

    def __init__(self, ebpf, fmt, address):
        self.ebpf = ebpf
        self.fmt = fmt
        self.address = address

    def __iadd__(self, value):
        # the IAdd returned here is recognized by `_set`
        # NOTE(review): `in` on a string fails for bit formats, as those
        # are tuples -- confirm that these are never added to in place
        if self.fmt in "qQiIx":
            return IAdd(self.ebpf, value)
        else:
            return NotImplemented

    def __isub__(self, value):
        # a subtraction is the addition of the negative value
        if self.fmt in "qQiIx":
            return IAdd(self.ebpf, -value)
        else:
            return NotImplemented

    @contextmanager
    def calculate(self, dst, long, force=False):
        """issue the code to load the value from memory into a register"""
        # values with a byte order are loaded without it and then
        # converted
        if self.has_endian():
            with self.without_endian().switch_endian(self.fmt) \
                 .calculate(dst, long, force) as (dst, long):
                yield dst, long
                return
        with ExitStack() as exitStack:
            if isinstance(self.address, Sum):
                # a register plus a constant can be addressed directly
                # by the load instruction
                dst = exitStack.enter_context(self.ebpf.get_free_register(dst))
                self.load(dst, self.address.left.no, self.address.right.value,
                          self.fmt, long)
            else:
                dst, _ = exitStack.enter_context(
                    super().calculate(dst, long, force))
            if isinstance(self.fmt, tuple):
                # mask out the bits wanted and shift them down
                self.ebpf.r[dst] &= ((1 << self.fmt[1]) - 1) << self.fmt[0]
                if self.fmt[0] > 0:
                    self.ebpf.r[dst] >>= self.fmt[0]
                # NOTE(review): everywhere else the second value is a
                # boolean telling whether the value is 64 bit. "B" is a
                # true value -- confirm this is intended.
                yield dst, "B"
            else:
                yield dst, self.fmt[-1] in "QqAx"

    @contextmanager
    def get_address(self, dst, long, force=False):
        """return the address of the value together with its format"""
        with self.address.calculate(dst, True) as (src, _):
            yield src, self.fmt

    def contains(self, no):
        return self.address.contains(no)

    @property
    def signed(self):
        # signed formats are the lower case ones. Bit formats are
        # never signed.
        return isinstance(self.fmt, str) and self.fmt.islower()

    @property
    def fixed(self):
        return isinstance(self.fmt, str) and self.fmt == "x"

    def has_endian(self):
        """return whether the format has a byte order character"""
        return isinstance(self.fmt, str) and len(self.fmt) > 1

    def without_endian(self):
        """return the same memory, but without byte order in the format"""
        if self.has_endian():
            return Memory(self.ebpf, self.fmt[-1], self.address)
        return self

    def __invert__(self):
        # only single bits can be inverted
        if not isinstance(self.fmt, tuple) or self.fmt[1] != 1:
            return NotImplemented
        return self == 0

    def __ne__(self, value):
        # testing bits for not being zero is done on the entire byte
        # with a mask, see AndExpression
        if isinstance(self.fmt, tuple) and isinstance(value, int) \
                and value == 0:
            mask = ((1 << self.fmt[1]) - 1) << self.fmt[0]
            return Memory(self.ebpf, "B", self.address) & mask != 0
        return super().__ne__(value)

    def _set(self, value):
        """issue the code to write `value` to this memory

        :param value: an expression, an :class:`IAdd` for an in-place
           addition, or a plain Python value
        """
        opcode = Opcode.STX
        with ExitStack() as exitStack:
            if isinstance(self.fmt, tuple):
                # bits are set by rewriting the entire byte, for which
                # this object itself is turned into that byte
                pos, bits = self.fmt
                self.fmt = "B"
                if bits == 1:
                    try:
                        if value:
                            value = self | (1 << pos)
                        else:
                            value = self & ~(1 << pos)
                    except AssembleError:
                        # the truth of the value is only known at
                        # execution time, so issue code for both cases
                        exitStack.enter_context(self.ebpf.wtmp)
                        with value as Else:
                            self.ebpf.wtmp = self | (1 << pos)
                        with Else:
                            self.ebpf.wtmp = self & ~(1 << pos)
                        value = self.ebpf.wtmp
                else:
                    # replace the bits within the mask by the new value
                    mask = ((1 << bits) - 1) << pos
                    value = (mask & (value << pos) | ~mask & self)
            elif isinstance(value, IAdd):
                value = value.value
                opcode = Opcode.XADD
            elif not isinstance(value, Expression):
                value = Constant(self.ebpf, value)
            # convert between fixed-point values and integers
            if self.fmt == "x" and not value.fixed:
                value *= Expression.FIXED_BASE
            elif self.fmt != "x" and value.fixed:
                value /= Expression.FIXED_BASE
            if isinstance(self.address, Sum):
                # a register plus a constant can be addressed directly
                dst = self.address.left.no
                offset = self.address.right.value
            else:
                dst, _ = exitStack.enter_context(
                    self.address.calculate(None, True))
                offset = 0
            value = value.switch_endian(self.fmt)
            if value.small_constant and opcode == Opcode.STX:
                # small constants are stored as immediate values,
                # without the need for a register
                self.ebpf.append(Opcode.ST + fmt_to_opcode(self.fmt), dst, 0,
                                 offset, int(value.value))
                return
            src, _ = exitStack.enter_context(
                value.calculate(None, isinstance(self.fmt, str)
                                      and self.fmt[-1] in 'qQx'))
            self.ebpf.append(opcode + fmt_to_opcode(self.fmt),
                             dst, src, offset, 0)



class MemoryDesc:
    """Common base for descriptors that represent a piece of memory

    Memory is always addressed relative to a base register.  Deriving
    classes name that register in their `base_register` attribute and
    provide a `fmt_addr` method returning the format and the offset to
    use for a given instance.
    """

    fixed = False  # only selected memory can have fixed value vars

    def __get__(self, instance, owner):
        # access through the class hands out the descriptor itself
        if instance is not None:
            fmt, offset = self.fmt_addr(instance)
            program = instance.ebpf
            return Memory(program, fmt,
                          program.r[self.base_register] + offset)
        return self

    def __set__(self, instance, value):
        fmt, offset = self.fmt_addr(instance)
        program = instance.ebpf
        target = Memory(program, fmt,
                        program.r[self.base_register] + offset)
        # the actual store is generated by the Memory object
        target._set(value)


class LocalVar(MemoryDesc):
    """variables on the stack

    This is how local variables on the stack are defined::

        class Program(EBPF):
            my_variable = LocalVar('I')

    :param fmt: the data format as :mod:`struct` characters
    """
    base_register = 10

    def __init__(self, fmt='I'):
        self.fmt = fmt
        # the format "x" marks a fixed value variable
        self.fixed = fmt == "x"

    def __set_name__(self, owner, name):
        """reserve room for this variable on the stack of `owner`

        The stack grows downwards: `owner.stack` is lowered by the size
        of the variable and then aligned to that size, the result is the
        address of this variable relative to the base register.
        """
        size = fmtsize(self.fmt)
        owner.stack -= size
        owner.stack &= -size
        self.relative_addr = owner.stack
        self.name = name

    def fmt_addr(self, instance):
        """return the format and the offset to the base register

        For a :class:`SubProgram` the offset is placed below the stack
        of the program it belongs to, rounded down to a multiple of 8.
        """
        if isinstance(instance, SubProgram):
            return (self.fmt,
                    (instance.ebpf.stack & -8) + self.relative_addr)
        else:
            return self.fmt, self.relative_addr
class Structure:
    """combine values into a structure

    This is used for keys and values of hash tables. members
    of this structure are of class :class:`Member`::

        class Point(Structure):
            x = Member('i')
            y = Member('i')
    """
    # total size in bytes, every Member adds its own size to this
    stack = 0
    base_register = 10

    def __init__(self):
        # a zero-initialized buffer holding the values of all members
        self.data = bytearray(self.stack)

    def __repr__(self):
        # only members declared directly on this class are listed
        members = ', '.join(
            f'{k}={getattr(self, k)}'
            for k, v in self.__class__.__dict__.items()
            if isinstance(v, Member))
        return f"{self.__class__.__name__}({members})"
class Member(LocalVar):
    """The member of a :class:`Structure`

    :param fmt: the data format as :mod:`struct` characters
    """
    def __set_name__(self, owner, name):
        """place this member at the current end of the structure `owner`

        :raises AssembleError: if the current end of the structure is not
            aligned to the size of this member
        """
        size = fmtsize(self.fmt)
        if owner.stack & (size - 1):
            raise AssembleError("structures must be packed")
        self.relative_addr = owner.stack
        owner.stack += size
        self.name = name

    def fmt_addr(self, instance):
        """return format and offset, shifted by ``instance.addr_offset``"""
        fmt, addr = super().fmt_addr(instance)
        return fmt, addr + instance.addr_offset

    def __get__(self, instance, owner):
        if instance is None:
            return self
        elif instance.data is None:
            # no local buffer: access memory relative to the base
            # register of the instance. Note that this changes the
            # descriptor itself, which is shared by all instances.
            self.base_register = instance.base_register
            return super().__get__(instance, owner)
        return unpack_from(self.fmt, instance.data, self.relative_addr)[0]

    def __set__(self, instance, value):
        if instance.data is None:
            # see __get__: generate a memory access instead
            self.base_register = instance.base_register
            return super().__set__(instance, value)
        pack_into(self.fmt, instance.data, self.relative_addr, value)
class MemoryMap:
    """access memory of a fixed format by indexing with an address

    ``m[addr]`` returns the :class:`Memory` at `addr`, assigning to
    ``m[addr]`` stores a value there.

    :param ebpf: the program the code is generated for
    :param fmt: the data format of the memory accessed
    """
    def __init__(self, ebpf, fmt):
        self.ebpf = ebpf
        self.fmt = fmt

    def __setitem__(self, addr, value):
        memory = Memory(self.ebpf, self.fmt, addr)
        memory._set(value)

    def __getitem__(self, addr):
        if isinstance(addr, Register):
            # presumably this turns the bare register into a sum of
            # register and offset -- TODO confirm
            addr = addr + 0
        return Memory(self.ebpf, self.fmt, addr)


class Map(ABC):
    """The base class for all maps"""
    def __set_name__(self, owner, name):
        self.filename = name

    @abstractmethod
    def init(self, ebpf, fd):
        """create the map and initialize its values

        :param fd: the file descriptor for an already existing map,
            or ``None`` if it need to be created.
        """

    def load(self, ebpf):
        """called after the program has been loaded"""


class PseudoFd(Expression):
    """represent a file descriptor to a map"""
    def __init__(self, ebpf, fd):
        self.ebpf = ebpf
        self.fd = fd
        self.fixed = False

    @contextmanager
    def calculate(self, dst, long, force=False):
        """load the file descriptor into a register

        Two instructions are appended: a ``DW`` one with source 1 that
        carries the file descriptor as immediate, followed by an all-zero
        ``W`` one. (NOTE(review): this looks like the two-slot 64 bit
        immediate load with a pseudo map fd -- confirm against the
        kernel documentation.)
        """
        with self.ebpf.get_free_register(dst) as dst:
            self.ebpf.append(Opcode.DW, dst, 1, 0, self.fd)
            self.ebpf.append(Opcode.W, 0, 0, 0, 0)
            yield dst, long
class ktime(Expression):
    """a function that returns the current ktime in ns"""
    signed = False

    def __init__(self, ebpf):
        self.ebpf = ebpf
        self.fixed = False

    @contextmanager
    def calculate(self, dst, long, force=False):
        """call ``ktime_get_ns`` and yield the register with the result

        The result is always reported as a long value, whatever `long`
        was requested.
        """
        with self.ebpf.get_free_register(dst) as dst:
            # keep the registers r0 to r5 over the call, except for the
            # one that receives the result
            with self.ebpf.save_registers([i for i in range(6) if i != dst]):
                self.ebpf.call(FuncId.ktime_get_ns)
                # the result is in r0, move it before r0 is restored
                if dst != 0:
                    self.ebpf.r[dst] = self.ebpf.r0
            yield dst, True
class prandom(Expression):
    """a function that returns a pseudo-random 32 bit number"""
    def __init__(self, ebpf):
        self.ebpf = ebpf
        # like the other expressions here, this is never a fixed
        # value; users of an expression read this attribute
        self.fixed = False

    @contextmanager
    def calculate(self, dst, long, force=False):
        """call ``get_prandom_u32`` and yield the register with the result

        The result is always reported as a long value, whatever `long`
        was requested.
        """
        with self.ebpf.get_free_register(dst) as dst:
            # keep the registers r0 to r5 over the call, except for the
            # one that receives the result
            with self.ebpf.save_registers([i for i in range(6) if i != dst]):
                self.ebpf.call(FuncId.get_prandom_u32)
                # the result is in r0, move it before r0 is restored
                if dst != 0:
                    self.ebpf.r[dst] = self.ebpf.r0
            yield dst, True
class RegisterDesc:
    """descriptor giving access to one register by attribute

    :param no: the number of the register
    :param array: the name of the attribute of the instance which holds
        the register array to look the register up in
    """
    def __init__(self, no, array):
        self.no = no
        self.array = array

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        else:
            return getattr(instance, self.array)[self.no]

    def __set__(self, instance, value):
        getattr(instance, self.array)[self.no] = value


class RegisterArray:
    """all registers of a program, seen as one kind of data

    :param ebpf: the program the registers belong to
    :param long: passed on as `long` to every register
    :param signed: passed on as `signed` to every register
    :param fixed: whether the registers hold fixed values
    """
    def __init__(self, ebpf, long, signed, fixed=False):
        self.ebpf = ebpf
        self.long = long
        self.signed = signed
        self.fixed = fixed

    def __setitem__(self, no, value):
        """generate the code to put `value` into register `no`"""
        self.ebpf.owners.add(no)
        value = ensure_expression(self.ebpf, value)
        # convert between fixed and non-fixed values where needed
        if self.fixed and not value.fixed:
            value *= Expression.FIXED_BASE
        elif not self.fixed and value.fixed:
            value /= Expression.FIXED_BASE
        # entering the context generates the code, there is nothing
        # left to do within
        with value.calculate(no, self.long, True):
            pass

    def __getitem__(self, no):
        return Register(no, self.ebpf, self.long, self.signed, self.fixed)


class Temporary(Register):
    """a register whose number is chosen upon entering a ``with`` block

    The register numbers and the context managers which allocated them
    are kept on stacks, so that the blocks may be nested.
    """
    def __init__(self, ebpf, long, signed, fixed):
        super().__init__(None, ebpf, long, signed, fixed)
        self.nos = []
        self.gfrs = []

    def __enter__(self):
        gfr = self.ebpf.get_free_register(None)
        self.nos.append(self.no)
        self.no = gfr.__enter__()
        self.gfrs.append(gfr)

    def __exit__(self, a, b, c):
        gfr = self.gfrs.pop()
        gfr.__exit__(a, b, c)
        self.no = self.nos.pop()


class TemporaryDesc(RegisterDesc):
    """descriptor for a :class:`Temporary` that is created on first use"""
    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        arr = getattr(instance, self.array)
        # every instance gets its own Temporary, which is cached in the
        # instance dictionary under the name of this descriptor
        ret = instance.__dict__.get(self.name, None)
        if ret is None:
            ret = instance.__dict__[self.name] = \
                Temporary(instance, arr.long, arr.signed, arr.fixed)
        return ret

    def __set__(self, instance, value):
        # assign to whatever register the temporary currently stands for
        no = getattr(instance, self.name).no
        getattr(instance, self.array)[no] = value


class EBPFBase:
    """functionality shared by the real and the simulated programs

    :param name: the name of the program. If not given, a `name` set on
        the class is kept, and the class name is used if there is none.
    :param subprograms: the sub programs of this program, their `ebpf`
        attribute is set to this program
    """
    name = None

    def __init__(self, name=None, subprograms=()):
        if name is None:
            if self.name is None:
                self.name = self.__class__.__name__
        else:
            self.name = name
        self.subprograms = subprograms
        for p in subprograms:
            p.ebpf = self

    @property
    def ebpf(self):
        return self


class SimulatedEBPF(EBPFBase):
    loaded = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # walk the entire class hierarchy, the first definition of a
        # name wins. `collect` and `get_array` are not defined in this
        # part of the file, they have to be provided elsewhere.
        unique = set()
        for cls in self.__class__.__mro__:
            for k, v in cls.__dict__.items():
                if k not in unique and isinstance(v, Map):
                    size = v.collect(self)
                    setattr(self, k, self.get_array(size))
                    unique.add(k)
class EBPF(EBPFBase):
    """The base class for all EBPF programs

    Usually this class is sub-classed, and the actual program is defined
    in the overwritten `program` method. Then the program may be loaded
    into the kernel. Alternatively, this class may even be instantiated
    directly, in which case you can just issue the program before it is
    loaded.

    After a program is loaded, its maps may be written to a bpf file
    system using :meth:`pin_maps`. Those maps may be used at a later time,
    especially also in a different task, if the parameter ``load_maps``
    is given, in which case we assume the program has already been loaded.

    :param load_maps: a prefix to load pinned maps from. Must be existing
        in a bpf file system, and usually ends in a "/".
    """
    stack = 0
    license = None

    def __init__(self, prog_type=0, license=None, *, kern_version=0,
                 load_maps=None, **kwargs):
        self.opcodes = []
        self.prog_type = prog_type
        if license is not None:
            self.license = license
        self.kern_version = kern_version
        self.loaded = load_maps is not None

        # memory access, one accessor per data format
        self.mB = MemoryMap(self, "B")
        self.mH = MemoryMap(self, "H")
        self.mI = MemoryMap(self, "I")
        self.mA = MemoryMap(self, "A")  # actually I, but treat as Q
        self.mQ = MemoryMap(self, "Q")

        self.mb = MemoryMap(self, "b")
        self.mh = MemoryMap(self, "h")
        self.mi = MemoryMap(self, "i")
        self.mq = MemoryMap(self, "q")

        self.mx = MemoryMap(self, "x")

        # register access, arguments are (long, signed, fixed)
        self.r = RegisterArray(self, True, False)
        self.sr = RegisterArray(self, True, True)
        self.w = RegisterArray(self, False, False)
        self.sw = RegisterArray(self, False, True)
        self.x = RegisterArray(self, True, True, True)

        # the registers in use, r1 and r10 are so from the beginning
        self.owners = {1, 10}

        super().__init__(**kwargs)

        # only maps defined directly on the class of the program are
        # initialized, maps of base classes are not
        for k, v in self.__class__.__dict__.items():
            if isinstance(v, Map):
                if load_maps is None:
                    v.init(self, None)
                else:
                    v.init(self, bpf.obj_get(load_maps + k))

    def pin_maps(self, path):
        """pin the maps of this program to files with prefix `path`

        This path must be in a bpf file system, and all parent
        directories must already exist, while the individual files
        must not exist.
        """
        for k, v in self.__class__.__dict__.items():
            if isinstance(v, Map):
                bpf.obj_pin(path + k, getattr(self, v.name).fd)

    def program(self):
        """overwrite this method with your program while subclassing"""

    def append(self, opcode, dst, src, off, imm):
        """append one instruction to the program"""
        self.opcodes.append(Instruction(opcode, dst, src, off, imm))

    def append_endian(self, fmt, dst):
        """append a byte order conversion of register `dst` if needed

        Nothing is done unless `fmt` is a string of exactly two
        characters, the first of which is the byte order as used by
        :mod:`struct`.

        :raises AssembleError: if the first character is not one of
            ``<``, ``>`` or ``!``
        """
        if not isinstance(fmt, str) or len(fmt) != 2:
            return
        endian = fmt[0]
        if endian == "<":
            opcode = Opcode.LE
        elif endian in ">!":
            opcode = Opcode.BE
        else:
            # without this branch `opcode` would be unbound below
            raise AssembleError(
                f"unsupported byte order in format {fmt!r}")
        self.append(opcode, dst, 0, 0, calcsize(fmt) * 8)

    def assemble(self):
        """return the assembled program"""
        sub(EBPF, self).program()
        # offset and immediate may be negative, wrap them to unsigned
        return b"".join(
            pack("<BBHI", i.opcode.value, i.dst | i.src << 4,
                 i.off % 0x10000, i.imm % 0x100000000)
            for i in self.opcodes)

    def load(self, log_level=0, log_size=10 * 4096):
        """load the program into the kernel"""
        # the name is cut to 15 characters, and characters which are
        # not allowed are replaced by an underscore
        name = ''.join(c if c in bpf.allowed_chars else '_'
                       for c in self.name[:15])
        fd, log = bpf.prog_load(self.prog_type, self.assemble(),
                                self.license, log_level, log_size,
                                self.kern_version, name=name)
        self.loaded = True
        self.file_descriptor = fd

        for v in self.__class__.__dict__.values():
            if isinstance(v, Map):
                v.load(self)
        return log

    def close(self):
        """close the file descriptor of the loaded program"""
        os.close(self.file_descriptor)
        self.file_descriptor = None

    def test_run(self, *args, **kwargs):
        """pass the loaded program on to ``bpf.prog_test_run``"""
        return bpf.prog_test_run(self.file_descriptor, *args, **kwargs)

    def jumpIf(self, comp):
        """jump if `comp` is true to a later defined `target`"""
        if isinstance(comp, Expression):
            comp = comp != 0
        comp.compare(False)
        return comp

    def jump(self):
        """unconditionally jump to a later defined `target`"""
        comp = SimpleComparison(self, None, 0, Opcode.JMP)
        comp.origin = len(self.opcodes)
        comp.dst = 0
        comp.owners = self.owners.copy()
        self.owners = set(range(11))
        # a placeholder, as the jump target is not known yet
        self.opcodes.append(None)
        return comp

    def get_fd(self, fd):
        """return the file descriptor `fd` of a map"""
        return PseudoFd(self, fd)

    def call(self, no):
        """call the kernel function `no` from enum `FuncId`"""
        assert isinstance(no, FuncId)
        self.append(Opcode.CALL, 0, 0, 0, no.value)
        # after the call r0 is in use, while r1 to r5 are not anymore
        self.owners.add(0)
        self.owners -= set(range(1, 6))

    def exit(self, no=None):
        """Exit the program with return value `no`"""
        if no is not None:
            self.r0 = no.value
        self.append(Opcode.EXIT, 0, 0, 0, 0)

    @contextmanager
    def get_free_register(self, dst):
        """a context manager to get hold of a register

        If `dst` is given it is simply handed back. Otherwise the first
        of the registers r0 to r9 which is not in use is marked as used
        for the duration of the context.

        :raises AssembleError: if all registers are in use
        """
        if dst is not None:
            yield dst
            return
        for i in range(10):
            if i not in self.owners:
                self.owners.add(i)
                yield i
                self.owners.discard(i)
                return
        raise AssembleError("not enough registers")

    @contextmanager
    def save_registers(self, registers):
        """a context manager which keeps `registers` over its body

        All `registers` are marked as used. Those which were in use
        before are copied to a free register upon entry, and copied
        back upon exit.
        """
        oldowners = self.owners.copy()
        registers = set(registers)
        self.owners |= registers
        save = []
        with ExitStack() as exitStack:
            for i in registers:
                if i in oldowners:
                    tmp = exitStack.enter_context(
                        self.get_free_register(None))
                    self.append(Opcode.MOV+Opcode.LONG+Opcode.REG,
                                tmp, i, 0, 0)
                    save.append((tmp, i))
            yield
            for tmp, i in save:
                self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, i, tmp, 0, 0)
            # NOTE(review): this also marks those registers as unused
            # which were in use before and have just been restored --
            # confirm that this is intended
            self.owners -= registers

    @contextmanager
    def get_stack(self, size):
        """a context manager to reserve `size` bytes on the stack

        The stack is lowered by `size` and rounded down to a multiple of
        it, which assumes `size` to be a power of two. The new value of
        the stack is yielded, the old one is restored upon exit.
        """
        oldstack = self.stack
        self.stack = (self.stack - size) & -size
        yield self.stack
        self.stack = oldstack

    # temporary registers, one for each kind of register array
    tmp = TemporaryDesc(None, "r")
    stmp = TemporaryDesc(None, "sr")
    wtmp = TemporaryDesc(None, "w")
    swtmp = TemporaryDesc(None, "sw")
    xtmp = TemporaryDesc(None, "x")
# give EBPF one attribute for each register and each way to look at
# it. Only the plain "r" view includes r10.
for i in range(11):
    setattr(EBPF, f"r{i}", RegisterDesc(i, "r"))

for i in range(10):
    setattr(EBPF, f"sr{i}", RegisterDesc(i, "sr"))

for i in range(10):
    setattr(EBPF, f"w{i}", RegisterDesc(i, "w"))

for i in range(10):
    setattr(EBPF, f"sw{i}", RegisterDesc(i, "sw"))

for i in range(10):
    setattr(EBPF, f"x{i}", RegisterDesc(i, "x"))


class SubProgram:
    stack = 0