Skip to content

Commit

Permalink
binja: compute call graph up front
Browse files Browse the repository at this point in the history
for cache friendliness. see #2402
  • Loading branch information
williballenthin committed Nov 27, 2024
1 parent a909d02 commit 319dbfe
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 78 deletions.
61 changes: 59 additions & 2 deletions capa/features/extractors/binja/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import logging
from typing import Iterator
from collections import defaultdict

import binaryninja as binja
from binaryninja import ILException
from binaryninja import Function, BinaryView, SymbolType, ILException, RegisterValueType, LowLevelILOperation

import capa.perf
import capa.features.extractors.elf
import capa.features.extractors.binja.file
import capa.features.extractors.binja.insn
Expand All @@ -26,6 +29,8 @@
StaticFeatureExtractor,
)

logger = logging.getLogger(__name__)


class BinjaFeatureExtractor(StaticFeatureExtractor):
def __init__(self, bv: binja.BinaryView):
Expand All @@ -36,6 +41,9 @@ def __init__(self, bv: binja.BinaryView):
self.global_features.extend(capa.features.extractors.binja.global_.extract_os(self.bv))
self.global_features.extend(capa.features.extractors.binja.global_.extract_arch(self.bv))

with capa.perf.timing("binary ninja: computing call graph"):
self._call_graph = self._build_call_graph()

def get_base_address(self):
return AbsoluteVirtualAddress(self.bv.start)

Expand All @@ -45,9 +53,58 @@ def extract_global_features(self):
def extract_file_features(self):
yield from capa.features.extractors.binja.file.extract_features(self.bv)

def _build_call_graph(self):
# from function address to function addresses
calls_from: defaultdict[int, set[int]] = defaultdict(set)
calls_to: defaultdict[int, set[int]] = defaultdict(set)

f: Function
for f in self.bv.functions:
bv: BinaryView = f.view

for bbil in f.llil:
for llil in bbil:
if llil.operation not in (
LowLevelILOperation.LLIL_CALL,
LowLevelILOperation.LLIL_CALL_STACK_ADJUST,
LowLevelILOperation.LLIL_JUMP,
LowLevelILOperation.LLIL_TAILCALL,
):
continue

if llil.dest.value.type not in (
RegisterValueType.ImportedAddressValue,
RegisterValueType.ConstantValue,
RegisterValueType.ConstantPointerValue,
):
continue

address = llil.dest.value.value

for sym in bv.get_symbols(address):
if not sym:
continue

if sym.type not in (
SymbolType.ImportAddressSymbol,
SymbolType.ImportedFunctionSymbol,
SymbolType.FunctionSymbol,
):
continue

calls_from[f.start].add(address)
calls_to[address].add(f.start)

call_graph = {
"calls_to": calls_to,
"calls_from": calls_from,
}

return call_graph

def get_functions(self) -> Iterator[FunctionHandle]:
for f in self.bv.functions:
yield FunctionHandle(address=AbsoluteVirtualAddress(f.start), inner=f)
yield FunctionHandle(address=AbsoluteVirtualAddress(f.start), inner=f, ctx={"call_graph": self._call_graph})

def extract_function_features(self, fh: FunctionHandle) -> Iterator[tuple[Feature, Address]]:
yield from capa.features.extractors.binja.function.extract_features(fh)
Expand Down
57 changes: 24 additions & 33 deletions capa/features/extractors/binja/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# See the License for the specific language governing permissions and limitations under the License.
from typing import Iterator

from binaryninja import Function, BinaryView, SymbolType, ILException, RegisterValueType, LowLevelILOperation
from binaryninja import Function, BinaryView, SymbolType

from capa.features.file import FunctionName
from capa.features.common import Feature, Characteristic
Expand All @@ -20,38 +20,24 @@ def extract_function_calls_to(fh: FunctionHandle):
"""extract callers to a function"""
func: Function = fh.inner

for caller in func.caller_sites:
# Everything that is a code reference to the current function is considered a caller, which actually includes
# many other references that are NOT a caller. For example, an instruction `push function_start` will also be
# considered a caller to the function
llil = None
try:
# Temporary fix for https://github.com/Vector35/binaryninja-api/issues/6020. Since `.llil` can throw an
# exception rather than returning None
llil = caller.llil
except ILException:
caller: int
for caller in fh.ctx["call_graph"].get("calls_to", []):
if caller == func.start:
continue

if (llil is None) or llil.operation not in [
LowLevelILOperation.LLIL_CALL,
LowLevelILOperation.LLIL_CALL_STACK_ADJUST,
LowLevelILOperation.LLIL_JUMP,
LowLevelILOperation.LLIL_TAILCALL,
]:
continue
yield Characteristic("calls to"), AbsoluteVirtualAddress(caller)

if llil.dest.value.type not in [
RegisterValueType.ImportedAddressValue,
RegisterValueType.ConstantValue,
RegisterValueType.ConstantPointerValue,
]:
continue

address = llil.dest.value.value
if address != func.start:
def extract_function_calls_from(fh: FunctionHandle):
"""extract callers from a function"""
func: Function = fh.inner

callee: int
for callee in fh.ctx["call_graph"].get("calls_from", []):
if callee == func.start:
continue

yield Characteristic("calls to"), AbsoluteVirtualAddress(caller.address)
yield Characteristic("calls from"), AbsoluteVirtualAddress(callee)


def extract_function_loop(fh: FunctionHandle):
Expand All @@ -72,13 +58,12 @@ def extract_function_loop(fh: FunctionHandle):
def extract_recursive_call(fh: FunctionHandle):
"""extract recursive function call"""
func: Function = fh.inner
bv: BinaryView = func.view
if bv is None:
return

for ref in bv.get_code_refs(func.start):
if ref.function == func:
caller: int
for caller in fh.ctx["call_graph"].get("calls_to", []):
if caller == func.start:
yield Characteristic("recursive call"), fh.address
return


def extract_function_name(fh: FunctionHandle):
Expand Down Expand Up @@ -108,4 +93,10 @@ def extract_features(fh: FunctionHandle) -> Iterator[tuple[Feature, Address]]:
yield feature, addr


FUNCTION_HANDLERS = (extract_function_calls_to, extract_function_loop, extract_recursive_call, extract_function_name)
FUNCTION_HANDLERS = (
extract_function_calls_to,
extract_function_calls_from,
extract_function_loop,
extract_recursive_call,
extract_function_name,
)
44 changes: 1 addition & 43 deletions capa/features/extractors/binja/insn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import capa.features.extractors.helpers
from capa.features.insn import API, MAX_STRUCTURE_SIZE, Number, Offset, Mnemonic, OperandNumber, OperandOffset
from capa.features.common import MAX_BYTES_FEATURE_SIZE, Bytes, String, Feature, Characteristic
from capa.features.address import Address, AbsoluteVirtualAddress
from capa.features.address import Address
from capa.features.extractors.binja.helpers import DisassemblyInstruction, visit_llil_exprs
from capa.features.extractors.base_extractor import BBHandle, InsnHandle, FunctionHandle

Expand Down Expand Up @@ -500,47 +500,6 @@ def extract_insn_cross_section_cflow(
yield Characteristic("cross section flow"), ih.address


def extract_function_calls_from(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[tuple[Feature, Address]]:
"""extract functions calls from features
most relevant at the function scope, however, its most efficient to extract at the instruction scope
"""
func: Function = fh.inner
bv: BinaryView = func.view

if bv is None:
return

for il in func.get_llils_at(ih.address):
if il.operation not in [
LowLevelILOperation.LLIL_CALL,
LowLevelILOperation.LLIL_CALL_STACK_ADJUST,
LowLevelILOperation.LLIL_TAILCALL,
]:
continue

dest = il.dest
if dest.operation == LowLevelILOperation.LLIL_CONST_PTR:
value = dest.value.value
yield Characteristic("calls from"), AbsoluteVirtualAddress(value)
elif dest.operation == LowLevelILOperation.LLIL_CONST:
yield Characteristic("calls from"), AbsoluteVirtualAddress(dest.value)
elif dest.operation == LowLevelILOperation.LLIL_LOAD:
indirect_src = dest.src
if indirect_src.operation == LowLevelILOperation.LLIL_CONST_PTR:
value = indirect_src.value.value
yield Characteristic("calls from"), AbsoluteVirtualAddress(value)
elif indirect_src.operation == LowLevelILOperation.LLIL_CONST:
yield Characteristic("calls from"), AbsoluteVirtualAddress(indirect_src.value)
elif dest.operation == LowLevelILOperation.LLIL_REG:
if dest.value.type in [
RegisterValueType.ImportedAddressValue,
RegisterValueType.ConstantValue,
RegisterValueType.ConstantPointerValue,
]:
yield Characteristic("calls from"), AbsoluteVirtualAddress(dest.value.value)


def extract_function_indirect_call_characteristic_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
Expand Down Expand Up @@ -590,6 +549,5 @@ def extract_features(f: FunctionHandle, bbh: BBHandle, insn: InsnHandle) -> Iter
extract_insn_peb_access_characteristic_features,
extract_insn_cross_section_cflow,
extract_insn_segment_access_features,
extract_function_calls_from,
extract_function_indirect_call_characteristic_features,
)

0 comments on commit 319dbfe

Please sign in to comment.