File: //opt/imunify360/venv/lib/python3.11/site-packages/imav/malwarelib/plugins/patch_vulnerabilities.py
"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Copyright © 2019 Cloud Linux Software Inc.
This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import logging
import time
import queue
import uuid
from collections import defaultdict
from collections.abc import Hashable
from defence360agent.api import inactivity
from defence360agent.contracts.messages import MessageType
from defence360agent.contracts.plugins import (
MessageSink,
MessageSource,
expect,
)
from defence360agent.utils import (
batched,
nice_iterator,
recurring_check,
safe_cancel_task,
)
from imav.malwarelib.config import VulnerabilityHitStatus
from imav.malwarelib.model import VulnerabilityHit
from imav.malwarelib.vulnerabilities.patcher import (
PatchResult,
VulnerabilityPatcher,
)
from imav.malwarelib.vulnerabilities.storage import PatchStorage
logger = logging.getLogger(__name__)
class PatchQueue:
    """Coalescing queue of files to patch, keyed by request source.

    Putting the same key several times before it is consumed merges the
    file sets, so each key is handed out once with the union of its files.
    """

    def __init__(self):
        # key -> set of files accumulated for that key
        self._queue = defaultdict(set)

    def put(self, key: Hashable, values: set):
        """Merge the set *values* into the pending files for *key*."""
        bucket = self._queue[key]
        # in-place union: ``bucket`` is the very set stored in the mapping
        bucket |= values

    def get(self) -> tuple[Hashable, set]:
        """Remove and return one ``(key, files)`` pair.

        Raises:
            queue.Empty: when nothing is pending.
        """
        try:
            entry = self._queue.popitem()
        except KeyError as missing:
            raise queue.Empty() from missing
        return entry

    def empty(self) -> bool:
        """Return True when no key is pending."""
        return len(self._queue) == 0
class Patch(MessageSink, MessageSource):
    """Plugin that patches vulnerable files requested via messages.

    ``VulnerabilityPatchTask`` messages are accumulated in a ``PatchQueue``
    keyed by ``(cause, initiator, manual)``. A recurring task drains that
    queue. For each batch of files it:

    * backs up the originals through ``PatchStorage``;
    * reports backup failures as ``VulnerabilityPatchFailed``;
    * runs the patcher once per file owner;
    * publishes a ``VulnerabilityPatch`` message with the outcome.
    """

    # Maximum number of files looked up and processed in one batch.
    _BATCH_SIZE = 10_000

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._queue = PatchQueue()
        self._loop = None
        self._sink = None
        self._patcher = None
        self._patch_task = None

    async def create_sink(self, loop):
        pass

    async def create_source(self, loop, sink):
        self._loop = loop
        self._sink = sink
        self._patcher = VulnerabilityPatcher(loop=loop, sink=sink)
        self._patch_task = loop.create_task(self.recurring_patch())

    async def shutdown(self):
        if self._patch_task:
            await safe_cancel_task(self._patch_task)

    @expect(MessageType.VulnerabilityPatchTask)
    async def process_patch_task(self, message: dict):
        """Enqueue the files listed in a ``VulnerabilityPatchTask`` message.

        Requests coming from the same ``(cause, initiator, manual)`` source
        are merged by the queue before they are processed.
        """
        source = (
            message.get("cause"),
            message.get("initiator"),
            message.get("manual", False),
        )
        # tolerate both a missing and an explicit ``None`` file list
        files_to_patch = message.get("filelist") or []
        self._queue.put(source, set(files_to_patch))

    async def _patch_vulnerabilities(self):
        """Drain the queue, patching every queued file batch by batch.

        NOTE(review): an entry is popped before it is processed, so if a
        batch raises, the remaining batches of that entry are dropped.
        """
        while not self._queue.empty():
            (cause, initiator, manual), files_to_patch = self._queue.get()
            # Manual requests may also re-patch previously reverted files.
            vulnerable_statuses = [VulnerabilityHitStatus.VULNERABLE]
            if manual:
                vulnerable_statuses.append(VulnerabilityHitStatus.REVERTED)
            for files_batch in batched(files_to_patch, n=self._BATCH_SIZE):
                with inactivity.track.task("patch_vulnerabilities"):
                    await self._patch_batch(
                        files_batch, vulnerable_statuses, cause, initiator
                    )

    async def _patch_batch(self, files_batch, statuses, cause, initiator):
        """Back up and patch the hits for *files_batch* whose status is in *statuses*."""
        hits = VulnerabilityHit.select().where(
            VulnerabilityHit.orig_file.in_(files_batch),
            VulnerabilityHit.status.in_(statuses),
        )
        succeeded, failed, not_exist = await PatchStorage.store_all(hits)
        if failed:
            for hit in failed:
                await self._sink.process_message(
                    MessageType.VulnerabilityPatchFailed(
                        message=(
                            "Failed to store the original from {}"
                            " to {}".format(hit.orig_file, PatchStorage.path)
                        ),
                        timestamp=int(time.time()),
                    )
                )
        if not_exist:
            VulnerabilityHit.delete_hits([hit.orig_file for hit in not_exist])
        hits_by_owner = VulnerabilityHit.group_by_attribute(
            succeeded,
            attribute="owner",
        )
        for user, owner_hits in hits_by_owner.items():
            await self._patch_user_hits(user, owner_hits, cause, initiator)

    async def _patch_user_hits(self, user, hits, cause, initiator):
        """Run the patcher on the *hits* owned by *user* and publish the result."""
        started = time.time()
        files = [hit.orig_file for hit in hits]
        # update status to avoid any races: PATCH_IN_PROGRESS hits are not
        # selected again by ``_patch_batch``
        VulnerabilityHit.set_status(
            hits, VulnerabilityHitStatus.PATCH_IN_PROGRESS
        )
        try:
            result, error, cmd = await self._patcher.start(user, files)
        except Exception:
            # Otherwise the hits would stay PATCH_IN_PROGRESS forever and
            # never be retried; mirror the "unable to patch" handling.
            VulnerabilityHit.set_status(
                hits, VulnerabilityHitStatus.VULNERABLE
            )
            raise
        await self._sink.process_message(
            MessageType.VulnerabilityPatch(
                hits=hits,
                result=result,
                cleanup_id=uuid.uuid4().hex,
                started=started,
                error=error,
                cause=cause,
                initiator=initiator,
                args=cmd,
            )
        )

    @recurring_check(1)
    async def recurring_patch(self):
        """Process the queue if anything is pending.

        Presumably re-invoked every second by ``recurring_check(1)`` — confirm.
        """
        if not self._queue.empty():
            await self._patch_vulnerabilities()
class PatchResultProcessor(MessageSink):
    """Apply the outcome of a ``VulnerabilityPatch`` run to the stored hits."""

    async def create_sink(self, loop):
        pass

    @staticmethod
    def _set_hit_status(
        hits: list[VulnerabilityHit], status: str, patched_at=None
    ):
        """Persist *status*/*patched_at* for *hits* and mirror it on the objects."""
        VulnerabilityHit.set_status(hits, status, patched_at)
        for hit in hits:
            hit.status = status
            hit.patched_at = patched_at

    @expect(MessageType.VulnerabilityPatch)
    async def process_patch_result(self, message: dict):
        """Classify every patched hit once and update the database.

        Each hit ends up in exactly one category:

        * patched — the result explicitly reports success; the hit is
          marked PATCHED with the current time;
        * gone — the result says the file does not exist and it really is
          missing on disk; the hit is deleted and not updated further;
        * unable to patch — everything else (no result, failure, or
          "not exist" for a file that is actually present); the hit is
          marked back as VULNERABLE.
        """
        hits: list[VulnerabilityHit] = message["hits"]
        result: PatchResult = message["result"]
        now = time.time()
        patched, unable_to_patch, not_exist = [], [], []
        async for hit in nice_iterator(hits, chunk_size=100):
            if hit not in result:
                # the patcher reported nothing for this hit
                unable_to_patch.append(hit)
                continue
            hit_result = result[hit]
            if hit_result.not_exist():
                # in case if procu2.php tries to clean/patch user file in root
                # dirs, it will be marked as non-existent due to 'Permission
                # denied' error which confuses users, consider it as unable
                # to cleanup/patch
                if hit.orig_file_path.exists():
                    unable_to_patch.append(hit)
                else:
                    not_exist.append(hit)
            elif hit_result.is_patched():
                patched.append(hit)
            else:
                # treat as failed unless success is explicitly stated
                unable_to_patch.append(hit)
        if not_exist:
            VulnerabilityHit.delete_hits([hit.orig_file for hit in not_exist])
        self._set_hit_status(patched, VulnerabilityHitStatus.PATCHED, now)
        if unable_to_patch:
            self._set_hit_status(
                unable_to_patch, VulnerabilityHitStatus.VULNERABLE
            )