Skip to content

Commit 831aeb0

Browse files
authored
Update abstracts.py
1 parent be0f579 commit 831aeb0

File tree

1 file changed

+79
-67
lines changed

1 file changed

+79
-67
lines changed

lib/cuckoo/common/abstracts.py

Lines changed: 79 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,8 @@ def __init__(self):
437437
)
438438

439439
super().__init__()
440+
self.conn = None
441+
self.conn_lock = threading.Lock()
440442

441443
def _initialize_check(self):
442444
"""Runs all checks when a machine manager is initialized.
@@ -446,9 +448,6 @@ def _initialize_check(self):
446448
if not self._version_check():
447449
raise CuckooMachineError("Libvirt version is not supported, please get an updated version")
448450

449-
# Preload VMs
450-
self.vms = self._fetch_machines()
451-
452451
# Base checks. Also attempts to shutdown any machines which are
453452
# currently still active.
454453
super()._initialize_check()
@@ -474,34 +473,33 @@ def start(self, label=None):
474473

475474
conn = self._connect(label)
476475

477-
snapshot_list = self.vms[label].snapshotListNames(flags=0)
476+
try:
477+
vm = conn.lookupByName(label)
478+
except libvirt.libvirtError as e:
479+
raise CuckooMachineError(f"Cannot find machine {label}") from e
478480

479-
# If a snapshot is configured try to use it.
480-
if vm_info.snapshot and vm_info.snapshot in snapshot_list:
481-
# Revert to desired snapshot, if it exists.
482-
log.debug("Using snapshot %s for virtual machine %s", vm_info.snapshot, label)
483-
try:
484-
vm = self.vms[label]
481+
snapshot = None
482+
try:
483+
snapshot_list = vm.snapshotListNames(flags=0)
484+
485+
# If a snapshot is configured try to use it.
486+
if vm_info.snapshot and vm_info.snapshot in snapshot_list:
487+
log.debug("Using snapshot %s for virtual machine %s", vm_info.snapshot, label)
485488
snapshot = vm.snapshotLookupByName(vm_info.snapshot, flags=0)
486-
self.vms[label].revertToSnapshot(snapshot, flags=0)
487-
except libvirt.libvirtError as e:
488-
msg = f"Unable to restore snapshot {vm_info.snapshot} on virtual machine {label}. Your snapshot MUST BE in running state!"
489-
raise CuckooMachineError(msg) from e
490-
finally:
491-
self._disconnect(conn)
492-
elif self._get_snapshot(label):
493-
snapshot = self._get_snapshot(label)
494-
log.debug("Using snapshot %s for virtual machine %s", snapshot.getName(), label)
495-
try:
496-
self.vms[label].revertToSnapshot(snapshot, flags=0)
497-
except libvirt.libvirtError as e:
498-
raise CuckooMachineError(f"Unable to restore snapshot on virtual machine {label}. Your snapshot MUST BE in running state!") from e
499-
finally:
500-
self._disconnect(conn)
501-
else:
502-
self._disconnect(conn)
489+
else:
490+
snapshot = self._get_snapshot(label, vm)
491+
except libvirt.libvirtError:
492+
log.warning("Unable to fetch snapshot for virtual machine %s", label)
493+
494+
if not snapshot:
503495
raise CuckooMachineError(f"No snapshot found for virtual machine {label}")
504496

497+
log.debug("Using snapshot %s for virtual machine %s", snapshot.getName(), label)
498+
try:
499+
vm.revertToSnapshot(snapshot, flags=0)
500+
except libvirt.libvirtError as e:
501+
raise CuckooMachineError(f"Unable to restore snapshot on virtual machine {label}. Your snapshot MUST BE in running state!") from e
502+
505503
# Check state.
506504
self._wait_status(label, self.RUNNING)
507505

@@ -521,14 +519,14 @@ def stop(self, label=None):
521519
# Force virtual machine shutdown.
522520
conn = self._connect(label)
523521
try:
524-
if not self.vms[label].isActive():
522+
vm = conn.lookupByName(label)
523+
if not vm.isActive():
525524
log.debug("Trying to stop an already stopped machine %s, skipping", label)
526525
else:
527-
self.vms[label].destroy() # Machete's way!
526+
vm.destroy() # Machete's way!
528527
except libvirt.libvirtError as e:
529528
raise CuckooMachineError(f"Error stopping virtual machine {label}: {e}") from e
530-
finally:
531-
self._disconnect(conn)
529+
532530
# Check state.
533531
self._wait_status(label, self.POWEROFF)
534532

@@ -544,8 +542,14 @@ def shutdown(self):
544542
except CuckooMachineError as e:
545543
log.warning("Unable to shutdown machine %s, please check manually. Error: %s", machine.label, e)
546544

547-
# Free handlers.
548-
self.vms = None
545+
# Close connection
546+
with self.conn_lock:
547+
if self.conn:
548+
try:
549+
self.conn.close()
550+
except libvirt.libvirtError:
551+
pass
552+
self.conn = None
549553

550554
def screenshot(self, label, path):
551555
"""Screenshot a running virtual machine.
@@ -587,11 +591,10 @@ def dump_memory(self, label, path):
587591
# it'll still be owned by root, so we can't delete it, but at least we can read it
588592
fd = open(path, "w")
589593
fd.close()
590-
self.vms[label].coreDump(path, flags=libvirt.VIR_DUMP_MEMORY_ONLY)
594+
vm = conn.lookupByName(label)
595+
vm.coreDump(path, flags=libvirt.VIR_DUMP_MEMORY_ONLY)
591596
except libvirt.libvirtError as e:
592597
raise CuckooMachineError(f"Error dumping memory virtual machine {label}: {e}") from e
593-
finally:
594-
self._disconnect(conn)
595598

596599
def _status(self, label):
597600
"""Gets current status of a vm.
@@ -613,11 +616,10 @@ def _status(self, label):
613616

614617
conn = self._connect(label)
615618
try:
616-
state = self.vms[label].state(flags=0)
619+
vm = conn.lookupByName(label)
620+
state = vm.state(flags=0)
617621
except libvirt.libvirtError as e:
618622
raise CuckooMachineError(f"Error getting status for virtual machine {label}: {e}") from e
619-
finally:
620-
self._disconnect(conn)
621623

622624
if state:
623625
if state[0] == 1:
@@ -644,29 +646,37 @@ def _connect(self, label=None):
644646
if not self.dsn:
645647
raise CuckooMachineError("You must provide a proper connection string")
646648

647-
try:
648-
return libvirt.open(self.dsn)
649-
except libvirt.libvirtError as e:
650-
raise CuckooMachineError("Cannot connect to libvirt") from e
649+
with self.conn_lock:
650+
if self.conn:
651+
try:
652+
if self.conn.isAlive():
653+
return self.conn
654+
except libvirt.libvirtError:
655+
pass
656+
657+
# Connection is dead
658+
try:
659+
self.conn.close()
660+
except libvirt.libvirtError:
661+
pass
662+
self.conn = None
663+
664+
try:
665+
self.conn = libvirt.open(self.dsn)
666+
except libvirt.libvirtError as e:
667+
raise CuckooMachineError("Cannot connect to libvirt") from e
651668

652-
def _disconnect(self, conn):
669+
return self.conn
670+
671+
def _disconnect(self, _conn):
653672
"""Disconnects to libvirt subsystem.
654673
@raise CuckooMachineError: if cannot disconnect from libvirt.
655674
"""
656-
try:
657-
conn.close()
658-
except libvirt.libvirtError as e:
659-
raise CuckooMachineError("Cannot disconnect from libvirt") from e
660-
661-
def _fetch_machines(self):
662-
"""Fetch machines handlers.
663-
@return: dict with machine label as key and handle as value.
664-
"""
665-
return {vm.label: self._lookup(vm.label) for vm in self.machines()}
675+
# Do nothing, keep connection open for reuse
676+
pass
666677

667678
def _lookup(self, label):
668679
"""Search for a virtual machine.
669-
@param conn: libvirt connection handle.
670680
@param label: virtual machine name.
671681
@raise CuckooMachineError: if virtual machine is not found.
672682
"""
@@ -675,8 +685,6 @@ def _lookup(self, label):
675685
vm = conn.lookupByName(label)
676686
except libvirt.libvirtError as e:
677687
raise CuckooMachineError(f"Cannot find machine {label}") from e
678-
finally:
679-
self._disconnect(conn)
680688
return vm
681689

682690
def _list(self):
@@ -685,22 +693,31 @@ def _list(self):
685693
"""
686694
conn = self._connect()
687695
try:
696+
if hasattr(conn, "listAllDomains"):
697+
# flags=0 returns all domains (active and inactive)
698+
return [dom.name() for dom in conn.listAllDomains(0)]
699+
700+
# Fallback for older libvirt versions
688701
names = conn.listDefinedDomains()
702+
for vid in conn.listDomainsID():
703+
try:
704+
dom = conn.lookupByID(vid)
705+
names.append(dom.name())
706+
except libvirt.libvirtError:
707+
continue
708+
return names
689709
except libvirt.libvirtError as e:
690710
raise CuckooMachineError("Cannot list domains") from e
691-
finally:
692-
self._disconnect(conn)
693-
return names
694-
695711
def _version_check(self):
696712
"""Check if libvirt release supports snapshots.
697713
@return: True or false.
698714
"""
699715
return libvirt.getVersion() >= 8000
700716

701-
def _get_snapshot(self, label):
717+
def _get_snapshot(self, label, vm):
702718
"""Get current snapshot for virtual machine
703719
@param label: virtual machine name
720+
@param vm: virtual machine handle
704721
@return None or current snapshot
705722
@raise CuckooMachineError: if cannot find current snapshot or
706723
when there are too many snapshots available
@@ -715,10 +732,7 @@ def _extract_creation_time(node):
715732
return xml.findtext("./creationTime")
716733

717734
snapshot = None
718-
conn = self._connect(label)
719735
try:
720-
vm = self.vms[label]
721-
722736
# Try to get the currrent snapshot, otherwise fallback on the latest
723737
# from config file.
724738
if vm.hasCurrentSnapshot(flags=0):
@@ -732,8 +746,6 @@ def _extract_creation_time(node):
732746
snapshot = sorted(all_snapshots, key=_extract_creation_time, reverse=True)[0]
733747
except libvirt.libvirtError:
734748
raise CuckooMachineError(f"Unable to get snapshot for virtual machine {label}")
735-
finally:
736-
self._disconnect(conn)
737749

738750
return snapshot
739751

0 commit comments

Comments
 (0)