Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fuzz/fuzz_targets/iommu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ fuzz_target!(|bytes: &[u8]| -> Corpus {
SeccompAction::Allow,
EventFd::new(EFD_NONBLOCK).unwrap(),
((MEM_SIZE - IOVA_SPACE_SIZE) as u64, (MEM_SIZE - 1) as u64),
64,
Copy link
Contributor

@russell-islam russell-islam Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use something like, const MAX_IOMMU_ADDRESS_WIDTH_BITS: u8 = 64;

Copy link
Contributor Author

@edigaryev edigaryev Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you've asked me to split the commits above, and @rbradford stated that commits should be independently compilable (atomic), this would not be possible to achieve because:

  • virtio-devices: iommu: allow limiting maximum address width in bits contains this line
  • vmm: introduce platform option to limit maximum IOMMU address width contains the MAX_IOMMU_ADDRESS_WIDTH_BITS constant

None,
)
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ fn create_app(default_vcpus: String, default_memory: String, default_rng: String
.arg(
Arg::new("platform")
.long("platform")
.help("num_pci_segments=<num_pci_segments>,iommu_segments=<list_of_segments>,serial_number=<dmi_device_serial_number>,uuid=<dmi_device_uuid>,oem_strings=<list_of_strings>")
.help("num_pci_segments=<num_pci_segments>,iommu_segments=<list_of_segments>,iommu_address_width=<bits>,serial_number=<dmi_device_serial_number>,uuid=<dmi_device_uuid>,oem_strings=<list_of_strings>")
.num_args(1)
.group("vm-config"),
)
Expand Down
42 changes: 41 additions & 1 deletion tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9111,7 +9111,10 @@ mod vfio {
.args(["--cpus", "boot=4"])
.args(["--memory", "size=4G"])
.args(["--kernel", fw_path(FwType::RustHypervisorFirmware).as_str()])
.args(["--device", format!("path={NVIDIA_VFIO_DEVICE}").as_str()])
.args([
"--device",
format!("path={NVIDIA_VFIO_DEVICE},iommu=on").as_str(),
])
.args(["--api-socket", &api_socket])
.default_disks()
.default_net()
Expand All @@ -9136,6 +9139,43 @@ mod vfio {

handle_child_output(r, &output);
}

#[test]
fn test_nvidia_card_iommu_address_width() {
let jammy = UbuntuDiskConfig::new(JAMMY_VFIO_IMAGE_NAME.to_string());
let guest = Guest::new(Box::new(jammy));
let api_socket = temp_api_path(&guest.tmp_dir);

let mut child = GuestCommand::new(&guest)
.args(["--cpus", "boot=4"])
.args(["--memory", "size=4G"])
.args(["--kernel", fw_path(FwType::RustHypervisorFirmware).as_str()])
.args(["--device", format!("path={NVIDIA_VFIO_DEVICE}").as_str()])
.args([
"--platform",
"num_pci_segments=2,iommu_segments=1,iommu_address_width=42",
])
.args(["--api-socket", &api_socket])
.default_disks()
.default_net()
.capture_output()
.spawn()
.unwrap();

let r = std::panic::catch_unwind(|| {
guest.wait_vm_boot(None).unwrap();

assert!(guest
.ssh_command("sudo dmesg")
.unwrap()
.contains("input address: 42 bits"));
});

let _ = child.kill();
let output = child.wait_with_output().unwrap();

handle_child_output(r, &output);
}
}

mod live_migration {
Expand Down
13 changes: 11 additions & 2 deletions virtio-devices/src/iommu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,9 +906,10 @@ impl Iommu {
seccomp_action: SeccompAction,
exit_evt: EventFd,
msi_iova_space: (u64, u64),
address_width_bits: u8,
state: Option<IommuState>,
) -> io::Result<(Self, Arc<IommuMapping>)> {
let (avail_features, acked_features, endpoints, domains, paused) =
let (mut avail_features, acked_features, endpoints, domains, paused) =
if let Some(state) = state {
info!("Restoring virtio-iommu {}", id);
(
Expand Down Expand Up @@ -939,12 +940,20 @@ impl Iommu {
(avail_features, 0, BTreeMap::new(), BTreeMap::new(), false)
};

let config = VirtioIommuConfig {
let mut config = VirtioIommuConfig {
page_size_mask: VIRTIO_IOMMU_PAGE_SIZE_MASK,
probe_size: PROBE_PROP_SIZE,
..Default::default()
};

if address_width_bits < 64 {
avail_features |= 1u64 << VIRTIO_IOMMU_F_INPUT_RANGE;
config.input_range = VirtioIommuRange64 {
start: 0,
end: (1u64 << address_width_bits) - 1,
}
}

let mapping = Arc::new(IommuMapping {
endpoints: Arc::new(RwLock::new(endpoints)),
domains: Arc::new(RwLock::new(domains)),
Expand Down
3 changes: 3 additions & 0 deletions vmm/src/api/openapi/cloud-hypervisor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,9 @@ components:
items:
type: integer
format: int16
iommu_address_width:
type: integer
format: uint8
serial_number:
type: string
uuid:
Expand Down
31 changes: 31 additions & 0 deletions vmm/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::landlock::LandlockAccess;
use crate::vm_config::*;

const MAX_NUM_PCI_SEGMENTS: u16 = 96;
const MAX_IOMMU_ADDRESS_WIDTH_BITS: u8 = 64;

/// Errors associated with VM configuration parameters.
#[derive(Debug, Error)]
Expand Down Expand Up @@ -183,6 +184,8 @@ pub enum ValidationError {
InvalidPciSegment(u16),
/// Invalid PCI segment aperture weight
InvalidPciSegmentApertureWeight(u32),
/// Invalid IOMMU address width in bits
InvalidIommuAddressWidthBits(u8),
/// Balloon too big
BalloonLargerThanRam(u64, u64),
/// On a IOMMU segment but not behind IOMMU
Expand Down Expand Up @@ -309,6 +312,9 @@ impl fmt::Display for ValidationError {
InvalidPciSegmentApertureWeight(aperture_weight) => {
write!(f, "Invalid PCI segment aperture weight: {aperture_weight}")
}
InvalidIommuAddressWidthBits(iommu_address_width_bits) => {
write!(f, "IOMMU address width in bits ({iommu_address_width_bits}) should be less than or equal to {MAX_IOMMU_ADDRESS_WIDTH_BITS}")
}
BalloonLargerThanRam(balloon_size, ram_size) => {
write!(
f,
Expand Down Expand Up @@ -817,6 +823,7 @@ impl PlatformConfig {
parser
.add("num_pci_segments")
.add("iommu_segments")
.add("iommu_address_width")
.add("serial_number")
.add("uuid")
.add("oem_strings");
Expand All @@ -834,6 +841,10 @@ impl PlatformConfig {
.convert::<IntegerList>("iommu_segments")
.map_err(Error::ParsePlatform)?
.map(|v| v.0.iter().map(|e| *e as u16).collect());
let iommu_address_width_bits: u8 = parser
.convert("iommu_address_width")
.map_err(Error::ParsePlatform)?
.unwrap_or(MAX_IOMMU_ADDRESS_WIDTH_BITS);
let serial_number = parser
.convert("serial_number")
.map_err(Error::ParsePlatform)?;
Expand All @@ -857,6 +868,7 @@ impl PlatformConfig {
Ok(PlatformConfig {
num_pci_segments,
iommu_segments,
iommu_address_width_bits,
serial_number,
uuid,
oem_strings,
Expand All @@ -882,6 +894,12 @@ impl PlatformConfig {
}
}

if self.iommu_address_width_bits > MAX_IOMMU_ADDRESS_WIDTH_BITS {
return Err(ValidationError::InvalidIommuAddressWidthBits(
self.iommu_address_width_bits,
));
}

Ok(())
}
}
Expand Down Expand Up @@ -3998,6 +4016,7 @@ mod tests {
PlatformConfig {
num_pci_segments: MAX_NUM_PCI_SEGMENTS,
iommu_segments: None,
iommu_address_width_bits: MAX_IOMMU_ADDRESS_WIDTH_BITS,
serial_number: None,
uuid: None,
oem_strings: None,
Expand Down Expand Up @@ -4296,6 +4315,18 @@ mod tests {
Err(ValidationError::InvalidPciSegment(MAX_NUM_PCI_SEGMENTS + 1))
);

let mut invalid_config = valid_config.clone();
invalid_config.platform = Some(PlatformConfig {
iommu_address_width_bits: MAX_IOMMU_ADDRESS_WIDTH_BITS + 1,
..platform_fixture()
});
assert_eq!(
invalid_config.validate(),
Err(ValidationError::InvalidIommuAddressWidthBits(
MAX_IOMMU_ADDRESS_WIDTH_BITS + 1
))
);

let mut still_valid_config = valid_config.clone();
still_valid_config.platform = Some(PlatformConfig {
iommu_segments: Some(vec![1, 2, 3]),
Expand Down
11 changes: 10 additions & 1 deletion vmm/src/device_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ use crate::pci_segment::PciSegment;
use crate::serial_manager::{Error as SerialManagerError, SerialManager};
use crate::vm_config::{
ConsoleOutputMode, DeviceConfig, DiskConfig, FsConfig, NetConfig, PmemConfig, UserDeviceConfig,
VdpaConfig, VhostMode, VmConfig, VsockConfig, DEFAULT_PCI_SEGMENT_APERTURE_WEIGHT,
VdpaConfig, VhostMode, VmConfig, VsockConfig, DEFAULT_IOMMU_ADDRESS_WIDTH_BITS,
DEFAULT_PCI_SEGMENT_APERTURE_WEIGHT,
};
use crate::{device_node, GuestRegionMmap, PciDeviceInfo, DEVICE_MANAGER_SNAPSHOT_ID};

Expand Down Expand Up @@ -1365,6 +1366,13 @@ impl DeviceManager {
) -> DeviceManagerResult<()> {
let iommu_id = String::from(IOMMU_DEVICE_NAME);

let iommu_address_width_bits =
if let Some(ref platform) = self.config.lock().unwrap().platform {
platform.iommu_address_width_bits
} else {
DEFAULT_IOMMU_ADDRESS_WIDTH_BITS
};

let iommu_device = if self.config.lock().unwrap().iommu {
let (device, mapping) = virtio_devices::Iommu::new(
iommu_id.clone(),
Expand All @@ -1373,6 +1381,7 @@ impl DeviceManager {
.try_clone()
.map_err(DeviceManagerError::EventFd)?,
self.get_msi_iova_space(),
iommu_address_width_bits,
state_from_id(self.snapshot.as_ref(), iommu_id.as_str())
.map_err(DeviceManagerError::RestoreGetState)?,
)
Expand Down
7 changes: 7 additions & 0 deletions vmm/src/vm_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,19 @@ pub fn default_platformconfig_num_pci_segments() -> u16 {
DEFAULT_NUM_PCI_SEGMENTS
}

pub const DEFAULT_IOMMU_ADDRESS_WIDTH_BITS: u8 = 64;
pub fn default_platformconfig_iommu_address_width_bits() -> u8 {
DEFAULT_IOMMU_ADDRESS_WIDTH_BITS
}

#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct PlatformConfig {
#[serde(default = "default_platformconfig_num_pci_segments")]
pub num_pci_segments: u16,
#[serde(default)]
pub iommu_segments: Option<Vec<u16>>,
#[serde(default = "default_platformconfig_iommu_address_width_bits")]
pub iommu_address_width_bits: u8,
#[serde(default)]
pub serial_number: Option<String>,
#[serde(default)]
Expand Down
Loading