Jump to content

Module:CineMol/model

From Wikipedia, the free encyclopedia
-- This is a port of CineMol to lua
-- CineMol https://github.com/moltools/CineMol was written by David Meijer, Marnix H. Medema & Justin J. J. van der Hooft and is MIT licensed
-- Please consider any edits made to this page as dual licensed MIT & CC-BY-SA 4.0

local p = {}

local geometry = require( 'Module:CineMol/geometry' )
local style = require( 'Module:CineMol/style' )
local cinemolsvg = require( 'Module:CineMol/svg' )
local calculate_convex_hull = require( 'Module:CineMol/fitting' ).calculate_convex_hull

local Cylinder = geometry.Cylinder 
local Line3D = geometry.Line3D 
local Point2D = geometry.Point2D 
local Point3D = geometry.Point3D 
local Sphere = geometry.Sphere 
local cylinder_intersects_with_cylinder = geometry.cylinder_intersects_with_cylinder 
local get_points_on_surface_cylinder = geometry.get_points_on_surface_cylinder 
local get_points_on_surface_sphere = geometry.get_points_on_surface_sphere 
local point_is_inside_cylinder = geometry.point_is_inside_cylinder 
local point_is_inside_sphere = geometry.point_is_inside_sphere 
local sphere_intersects_with_cylinder = geometry.sphere_intersects_with_cylinder 
local sphere_intersects_with_sphere = geometry.sphere_intersects_with_sphere 
local checkType = geometry.checkType

local Cartoon = style.Cartoon
local Color = style.Color
local Depiction = style.Depiction
local Fill = style.Fill
local FillStyle = style.FillStyle
local Glossy = style.Glossy
local LinearGradient = style.LinearGradient
local RadialGradient = style.RadialGradient
local Solid = style.Solid
local Wire = style.Wire

local Line2D = cinemolsvg.Line2D
local Polygon2D = cinemolsvg.Polygon2D
local Svg = cinemolsvg.Svg
local ViewBox = cinemolsvg.ViewBox

-- End of imports

--- Project a 3D point onto the 2D drawing plane using a perspective divide.
-- With f = focal_length, a point (x, y, z) maps to (f*x / (f - z), f*y / (f - z)).
-- When f - z is not strictly positive (z >= f), that divide would be by zero or
-- would mirror the point. Such points fall back to a plain orthographic
-- projection that simply drops z.
-- @param point Point3D to project
-- @param focal_length number
-- @return Point2D
function p.apply_focal_length(point, focal_length)
	checkType('apply_focal_length', 1, point, 'Point3D')
	checkType('apply_focal_length', 2, focal_length, 'number')

	local px, py = point.x, point.y
	local denominator = focal_length - point.z

	-- Guard written as `not (d > 0)` so NaN also takes the orthographic path.
	if not (denominator > 0) then
		return Point2D(px, py)
	end

	return Point2D((focal_length * px) / denominator, (focal_length * py) / denominator)
end

--- Pair a Sphere geometry with the Depiction used to draw it.
-- @param sphere Sphere geometry (stored as the node's `geometry` field)
-- @param depiction Depiction describing how the sphere is drawn
-- @return node table tagged _TYPE = 'ModelSphere'
function p.ModelSphere(sphere, depiction)
	checkType( 'ModelSphere', 1, sphere, 'Sphere' )
	checkType( 'ModelSphere', 2, depiction, 'Depiction' )
	local node = { _TYPE = 'ModelSphere', geometry = sphere, depiction = depiction }
	return node
end

--- Pair a Cylinder geometry with the Depiction used to draw it.
-- @param cylinder Cylinder geometry (stored as the node's `geometry` field)
-- @param depiction Depiction describing how the cylinder is drawn
-- @return node table tagged _TYPE = 'ModelCylinder'
function p.ModelCylinder(cylinder, depiction)
	checkType( 'ModelCylinder', 1, cylinder, 'Cylinder' )
	checkType( 'ModelCylinder', 2, depiction, 'Depiction' )
	local node = { _TYPE = 'ModelCylinder', geometry = cylinder, depiction = depiction }
	return node
end

--- Describe a line segment that is drawn as a plain stroke.
-- @param segment Line3D geometry (stored as the node's `geometry` field)
-- @param color Color of the stroke
-- @param width number stroke width
-- @param opacity number stroke opacity
-- @return node table tagged _TYPE = 'ModelWire'
function p.ModelWire(segment, color, width, opacity)
	checkType( 'ModelWire', 1, segment, 'Line3D' )
	checkType( 'ModelWire', 2, color, 'Color' )
	checkType( 'ModelWire', 3, width, 'number' )
	checkType( 'ModelWire', 4, opacity, 'number' )
	local node = {
		_TYPE = 'ModelWire',
		geometry = segment,
		color = color,
		width = width,
		opacity = opacity,
	}
	return node
end

-- ==============
-- Create visible 2D polygon from node geometry
-- ==============

-- Whether `point` lies inside any ModelSphere or ModelCylinder node in `occluders`.
-- Other node kinds (e.g. ModelWire) never hide anything and are skipped.
local function is_occluded( point, occluders )
	for _, node in ipairs( occluders ) do
		if node._TYPE == 'ModelSphere' and point_is_inside_sphere( node.geometry, point ) then
			return true
		elseif node._TYPE == 'ModelCylinder' and point_is_inside_cylinder( node.geometry, point ) then
			return true
		end
	end
	return false
end

--- Get the vertices of the polygon that represents the visible part of a node.
-- Samples points on the node's surface and discards every point that lies inside
-- one of the `others` nodes. The remaining points are projected to 2D and their
-- convex hull is returned.
-- (This would arguably fit as a method of a Model base class, but the upstream
-- Python version is a free function, so this port follows that.)
-- @param this ModelSphere or ModelCylinder node; any other node type yields {}
-- @param others list of nodes that may hide parts of `this`
-- @param resolution number; passed twice to the sphere sampler, and halved then
--   floored for the cylinder sampler
-- @param focal_length optional number; when given, points are projected with
--   p.apply_focal_length, otherwise by dropping z
-- @return list of Point2D hull vertices, or {} if the node type is unsupported
--   or no sampled point is visible
function p.get_node_polygon_vertices( this, others, resolution, focal_length )
	checkType( 'get_node_polygon_vertices', 1, this, 'table' )
	checkType( 'get_node_polygon_vertices', 2, others, 'table' )
	checkType( 'get_node_polygon_vertices', 3, resolution, 'number' )

	local points
	if this._TYPE == 'ModelSphere' then
		points = get_points_on_surface_sphere( this.geometry, resolution, resolution, true )
	elseif this._TYPE == 'ModelCylinder' then
		points = get_points_on_surface_cylinder( this.geometry, math.floor( resolution / 2.0 ) )
	else
		-- Unsupported geometry: nothing to draw.
		return {}
	end

	local visible_points = {}
	for _, point in ipairs( points ) do
		if not is_occluded( point, others ) then
			local projected
			if focal_length ~= nil then
				-- Rebuilt as a Point3D because apply_focal_length type-checks its argument.
				projected = p.apply_focal_length( Point3D( point.x, point.y, point.z ), focal_length )
			else
				projected = Point2D( point.x, point.y )
			end
			visible_points[#visible_points + 1] = projected
		end
	end

	-- Fully hidden node: no polygon.
	if #visible_points == 0 then
		return {}
	end

	-- calculate_convex_hull returns indices into visible_points.
	local verts = {}
	for _, ind in ipairs( calculate_convex_hull( visible_points ) ) do
		verts[#verts + 1] = visible_points[ind]
	end

	return verts
end

-- =============
-- Draw scene
-- =============

--- Filter, rotate, and scale nodes based on what to include in the scene.
-- Keeps ModelSphere / ModelCylinder / ModelWire nodes whose matching include_*
-- flag is truthy, in input order.
-- For each kept node, every defining point (sphere center; cylinder/wire start
-- and endp) is rotated and then, when `scale` is given, multiplied by `scale`.
-- When scaling, sphere/cylinder radii, Cartoon outline widths and wire widths are
-- scaled too.
-- NOTE: the node tables are modified IN PLACE and the same tables are returned,
-- so calling this twice on the same nodes applies the transform twice.
-- @param nodes_to_sort list of model nodes
-- @param include_spheres, include_cylinders, include_wires which kinds to keep
-- @param rotation_over_x_axis, rotation_over_y_axis, rotation_over_z_axis
--   rotation passed to Point3D:rotate (each defaults to 0)
-- @param scale optional number; nil means no scaling
-- @return list of the kept (mutated) nodes
function p.prepare_nodes_for_intersecting( nodes_to_sort, include_spheres, include_cylinders, include_wires, rotation_over_x_axis, rotation_over_y_axis, rotation_over_z_axis, scale )
	if rotation_over_x_axis == nil then rotation_over_x_axis = 0 end
	if rotation_over_y_axis == nil then rotation_over_y_axis = 0 end
	if rotation_over_z_axis == nil then rotation_over_z_axis = 0 end
	checkType( 'prepare_nodes_for_intersecting', 1, nodes_to_sort, 'table' )
	checkType( 'prepare_nodes_for_intersecting', 5, rotation_over_x_axis, 'number' )
	checkType( 'prepare_nodes_for_intersecting', 6, rotation_over_y_axis, 'number' )
	checkType( 'prepare_nodes_for_intersecting', 7, rotation_over_z_axis, 'number' )

	-- Rotate a point, then scale it about the origin if a scale was requested.
	local function transform( pt )
		local rotated = pt:rotate( rotation_over_x_axis, rotation_over_y_axis, rotation_over_z_axis )
		if scale == nil then
			return rotated
		end
		return Point3D( rotated.x * scale, rotated.y * scale, rotated.z * scale )
	end

	-- Keep a Cartoon outline proportional to the scaled geometry.
	local function scale_cartoon_outline( node )
		if node.depiction.name == 'Cartoon' then
			node.depiction.outline_width = node.depiction.outline_width * scale
		end
	end

	local kept = {}
	for _, node in ipairs( nodes_to_sort ) do
		local kind = node._TYPE

		if kind == 'ModelSphere' and include_spheres then
			local shape = node.geometry
			shape.center = transform( shape.center )
			if scale ~= nil then
				shape.radius = shape.radius * scale
				scale_cartoon_outline( node )
			end
			kept[#kept + 1] = node

		elseif kind == 'ModelCylinder' and include_cylinders then
			local shape = node.geometry
			shape.start = transform( shape.start )
			shape.endp = transform( shape.endp )
			if scale ~= nil then
				shape.radius = shape.radius * scale
				scale_cartoon_outline( node )
			end
			kept[#kept + 1] = node

		elseif kind == 'ModelWire' and include_wires then
			local shape = node.geometry
			shape.start = transform( shape.start )
			shape.endp = transform( shape.endp )
			if scale ~= nil then
				node.width = node.width * scale
			end
			kept[#kept + 1] = node
		end
	end

	return kept
end

--- Create the Fill (style definition) for a node.
-- - ModelWire: Wire stroke built from the node's color, width and opacity.
-- - Cartoon depiction: Solid fill with the depiction's outline.
-- - Glossy ModelSphere: RadialGradient centred on the sphere's (x, y).
-- - Glossy ModelCylinder: LinearGradient from the start (x, y) to the end (x, y).
-- @param node model node table
-- @param reference string id linking this fill to the shape drawn for the node
-- @return Fill, or nil when the node has no depiction or its depiction/type
--   combination is not handled here
-- @raise error if a Glossy ModelSphere/ModelCylinder carries the wrong geometry type
function p.create_fill( node, reference )
	checkType( 'create_fill', 1, node, 'table' )
	checkType( 'create_fill', 2, reference, 'string' )

	if node._TYPE == 'ModelWire' then
		local wire_style = Wire( node.color, node.width, node.opacity )
		return Fill( reference, wire_style )
	end

	local depiction = node.depiction
	if depiction == nil then
		-- Not a wire and nothing describing how to paint it: no fill.
		return nil
	end

	if depiction.name == 'Cartoon' then
		local solid_style = Solid( depiction.fill_color, depiction.outline_color, depiction.outline_width, depiction.opacity )
		return Fill( reference, solid_style )
	end

	-- Glossy style is different for spheres and cylinders.
	if depiction.name == 'Glossy' then
		if node._TYPE == 'ModelSphere' then
			-- Previously this check was unreachable (the branch condition already
			-- required a Sphere), so a bad geometry silently produced no fill.
			if node.geometry._TYPE ~= 'Sphere' then
				error("Node geometry of ModelSphere must be a sphere.")
			end
			local center = Point2D( node.geometry.center.x, node.geometry.center.y )
			local radial_style = RadialGradient( depiction.fill_color, center, node.geometry.radius, depiction.opacity )
			return Fill( reference, radial_style )
		elseif node._TYPE == 'ModelCylinder' then
			if node.geometry._TYPE ~= 'Cylinder' then
				error("Node geometry of ModelCylinder must be a cylinder.")
			end
			local start_center = Point2D( node.geometry.start.x, node.geometry.start.y )
			local end_center = Point2D( node.geometry.endp.x, node.geometry.endp.y )
			local linear_style = LinearGradient( depiction.fill_color, start_center, end_center, depiction.opacity )
			return Fill( reference, linear_style )
		end
	end

	return nil
end

--- Calculate which of the previous nodes intersect with the current node.
-- A ModelWire `node` always gets an empty list.
-- For any other node, when filter_nodes_for_intersecting is falsy, every entry
-- of other_nodes is returned. Otherwise only sphere/cylinder pairs are kept,
-- and only when the matching calculate_*_intersections flag is truthy and the
-- geometries intersect.
-- @param node model node being drawn
-- @param other_nodes list of previously drawn nodes
-- @param calculate_sphere_sphere_intersections boolean
-- @param calculate_sphere_cylinder_intersections boolean (node sphere, other cylinder)
-- @param calculate_cylinder_sphere_intersections boolean (node cylinder, other sphere)
-- @param calculate_cylinder_cylinder_intersections boolean
-- @param filter_nodes_for_intersecting boolean; falsy disables all filtering
-- @return new list with the selected entries of other_nodes, in order
function p.calculate_intersecting_nodes(
	node,
	other_nodes,
	calculate_sphere_sphere_intersections,
	calculate_sphere_cylinder_intersections,
	calculate_cylinder_sphere_intersections,
	calculate_cylinder_cylinder_intersections,
	filter_nodes_for_intersecting
)
	checkType( 'calculate_intersecting_nodes', 1, node, 'table' )
	checkType( 'calculate_intersecting_nodes', 2, other_nodes, 'table' )

	local hits = {}

	-- Wireframe is drawn as a line and has no intersections with other nodes.
	if node._TYPE == 'ModelWire' then
		return hits
	end

	local own_kind = node._TYPE

	-- Filtering predicate: is this (node, other) pair enabled and intersecting?
	local function keeps( other )
		if own_kind == 'ModelSphere' then
			local other_kind = other._TYPE
			if other_kind == 'ModelSphere' then
				return calculate_sphere_sphere_intersections
					and sphere_intersects_with_sphere( node.geometry, other.geometry )
			elseif other_kind == 'ModelCylinder' then
				return calculate_sphere_cylinder_intersections
					and sphere_intersects_with_cylinder( node.geometry, other.geometry )
			end
		elseif own_kind == 'ModelCylinder' then
			local other_kind = other._TYPE
			if other_kind == 'ModelSphere' then
				return calculate_cylinder_sphere_intersections
					and sphere_intersects_with_cylinder( other.geometry, node.geometry )
			elseif other_kind == 'ModelCylinder' then
				return calculate_cylinder_cylinder_intersections
					and cylinder_intersects_with_cylinder( node.geometry, other.geometry )
			end
		end
		return false
	end

	for _, other in ipairs( other_nodes ) do
		if not filter_nodes_for_intersecting or keeps( other ) then
			hits[#hits + 1] = other
		end
	end

	return hits
end

-- Return a new list with the elements of `list` that come strictly BEFORE
-- index n, i.e. list[1] .. list[n - 1]. Despite the name, that is n - 1
-- elements, not n. Copying stops early at the first nil, like ipairs.
-- Scene:draw calls first_n(nodes, i) to get the nodes drawn before node i.
local function first_n(list, n)
	local prefix = {}
	local idx = 1
	-- Check for nil first so an empty list never compares `idx` with `n`.
	while list[idx] ~= nil and idx < n do
		prefix[idx] = list[idx]
		idx = idx + 1
	end
	return prefix
end

--- Create a scene: a container of model nodes that can be rendered to an SVG.
-- @param nodes optional list of ModelSphere / ModelCylinder / ModelWire nodes;
--   defaults to a new empty list. The table is stored by reference, not copied.
-- @return table with _TYPE 'Scene', a `nodes` list and the methods
--   add_node, calculate_view_box and draw
function p.Scene( nodes )
	nodes = nodes == nil and {} or nodes
	checkType( 'Scene', 1, nodes, 'table' )
	local obj = {
		_TYPE = 'Scene',
		nodes = nodes
	}

	--- Append a node to the scene. The node itself is not type-checked.
	function obj:add_node( node )
		checkType( 'add_node', 1, self, 'Scene' )
		self.nodes[#self.nodes+1] = node
	end

	--- Axis-aligned bounding box of 2D points, grown by `margin` on every side.
	-- @param points list of tables with numeric x and y fields
	-- @param margin number added on each side
	-- @return ViewBox(min_x, min_y, width, height); ViewBox(0, 0, 0, 0) for no points
	function obj:calculate_view_box( points, margin )
		checkType( 'Scene:calculate_view_box', 1, self, 'Scene' )
		checkType( 'Scene:calculate_view_box', 2, points, 'table' )
		checkType( 'Scene:calculate_view_box', 3, margin, 'number' )

		if #points == 0 then
			-- The upstream version is missing this return, presumably a bug.
			return ViewBox(0, 0, 0, 0)
		end

		local min_x, min_y, max_x, max_y = 1/0, 1/0, -1/0, -1/0
		for _, point in ipairs(points) do
			min_x = math.min(min_x, point.x)
			min_y = math.min(min_y, point.y)
			max_x = math.max(max_x, point.x)
			max_y = math.max(max_y, point.y)
		end

		min_x = min_x - margin
		min_y = min_y - margin
		max_x = max_x + margin
		max_y = max_y + margin

		local width = max_x - min_x
		local height = max_y - min_y
		return ViewBox(min_x, min_y, width, height)
	end

	--- Draw the scene.
	-- Nodes are filtered by kind, then rotated and scaled, then sorted by
	-- ascending z (sphere center z, or midpoint z of a cylinder/wire axis).
	-- Each node is then drawn in that order. A sphere or cylinder becomes the
	-- convex hull of its surface points that are not inside previously drawn
	-- nodes. A wire becomes a line.
	-- NOTE: p.prepare_nodes_for_intersecting rotates and scales the geometries of
	-- self.nodes IN PLACE, so calling draw more than once on the same scene
	-- compounds the transformation.
	-- @param resolution number of surface samples (see p.get_node_polygon_vertices)
	-- @param window passed through to Svg; might not work due to how SVGs are
	--   made here, so set the img attribute on the svg instead
	-- @param view_box optional ViewBox; when nil, one is fitted to all drawn points
	--   with a margin of 5.0
	-- @param rotation_over_x_axis, rotation_over_y_axis, rotation_over_z_axis
	--   rotation passed to Point3D:rotate (default 0)
	-- @param include_spheres, include_cylinders, include_wires which node kinds to
	--   draw (default true)
	-- @param calculate_sphere_sphere_intersections, calculate_sphere_cylinder_intersections,
	--   calculate_cylinder_sphere_intersections, calculate_cylinder_cylinder_intersections
	--   which node pairs are tested for intersection (default true). They are only
	--   consulted when filter_nodes_for_intersecting is truthy.
	-- @param filter_nodes_for_intersecting when truthy, only previously drawn nodes
	--   that intersect the current one are used as occluders.
	--   NOTE(review): unlike the other flags this one has no default, so omitting
	--   it uses every earlier node. Confirm this matches upstream CineMol.
	-- @param scale number (default 1.0)
	-- @param focal_length optional number; enables perspective projection
	-- @return Svg
	function obj:draw(
		resolution,
		window,
		view_box,
		rotation_over_x_axis,
		rotation_over_y_axis,
		rotation_over_z_axis,
		include_spheres,
		include_cylinders,
		include_wires,
		calculate_sphere_sphere_intersections,
		calculate_sphere_cylinder_intersections,
		calculate_cylinder_sphere_intersections,
		calculate_cylinder_cylinder_intersections,
		filter_nodes_for_intersecting,
		scale,
		focal_length
	)
		checkType( 'Scene:draw', 1, self, 'Scene' )
		checkType( 'Scene:draw', 2, resolution, 'number' )

		-- `x == nil and true or x` keeps an explicit false (false and true -> false -> or x).
		rotation_over_x_axis = rotation_over_x_axis == nil and 0 or rotation_over_x_axis
		rotation_over_y_axis = rotation_over_y_axis == nil and 0 or rotation_over_y_axis
		rotation_over_z_axis = rotation_over_z_axis == nil and 0 or rotation_over_z_axis
		include_spheres = include_spheres == nil and true or include_spheres
		include_cylinders = include_cylinders == nil and true or include_cylinders
		include_wires = include_wires == nil and true or include_wires
		calculate_sphere_sphere_intersections = calculate_sphere_sphere_intersections == nil and true or calculate_sphere_sphere_intersections
		calculate_sphere_cylinder_intersections = calculate_sphere_cylinder_intersections == nil and true or calculate_sphere_cylinder_intersections
		calculate_cylinder_sphere_intersections = calculate_cylinder_sphere_intersections == nil and true or calculate_cylinder_sphere_intersections
		calculate_cylinder_cylinder_intersections = calculate_cylinder_cylinder_intersections == nil and true or calculate_cylinder_cylinder_intersections
		scale = scale == nil and 1.0 or scale
		checkType( 'Scene:draw', 16, scale, 'number' )

		-- Filter geometries (mutates the node geometries in place).
		local filtered_nodes = p.prepare_nodes_for_intersecting(
			self.nodes,
			include_spheres,
			include_cylinders,
			include_wires,
			rotation_over_x_axis,
			rotation_over_y_axis,
			rotation_over_z_axis,
			scale
		)

		-- Sort on z, as the scene is always viewed from the z-axis towards the
		-- origin. The original position breaks ties, giving a stable sort that
		-- emulates the ordering of the Python project.
		local sorting_values = {}
		for i, node in ipairs( filtered_nodes ) do
			if node._TYPE == 'ModelSphere' then
				sorting_values[#sorting_values+1] = { node.geometry.center.z, i, node }
			elseif node._TYPE == 'ModelCylinder' or node._TYPE == 'ModelWire' then
				local midpoint_z = (node.geometry.start.z + node.geometry.endp.z) / 2
				sorting_values[#sorting_values+1] = { midpoint_z, i, node }
			end
		end

		table.sort( sorting_values, function( a, b )
			if a[1] == b[1] then
				return a[2] < b[2]
			end
			return a[1] < b[1]
		end )

		local sorted_nodes = {}
		for k, entry in ipairs( sorting_values ) do
			sorted_nodes[k] = entry[3]
		end

		-- Reference points for fitting the view box (only collected when needed).
		local ref_points = {}

		-- Shapes and fills that populate the SVG.
		local objects = {}
		local fills = {}

		for i, node in ipairs( sorted_nodes ) do
			-- Reference tag connecting shape to style; i - 1 matches the Python
			-- version's numbering for easier testing.
			local reference = "node-" .. (i-1)

			-- Which of the previously drawn nodes (those before i) may hide this one.
			local previous_nodes = p.calculate_intersecting_nodes(
				node,
				first_n( sorted_nodes, i ),
				calculate_sphere_sphere_intersections,
				calculate_sphere_cylinder_intersections,
				calculate_cylinder_sphere_intersections,
				calculate_cylinder_cylinder_intersections,
				filter_nodes_for_intersecting
			)

			if node._TYPE == 'ModelWire' then
				-- Wires are drawn as a single 2D line.
				local start, endp
				if focal_length ~= nil then
					start = p.apply_focal_length( node.geometry.start, focal_length )
					endp = p.apply_focal_length( node.geometry.endp, focal_length )
				else
					start = Point2D( node.geometry.start.x, node.geometry.start.y )
					endp = Point2D( node.geometry.endp.x, node.geometry.endp.y )
				end

				if view_box == nil then
					ref_points[#ref_points+1] = start
					ref_points[#ref_points+1] = endp
				end
				objects[#objects+1] = Line2D( reference, start, endp )
			else
				-- Polygon for the visible part of the node.
				local points = p.get_node_polygon_vertices( node, previous_nodes, resolution, focal_length )

				if view_box == nil then
					for _, pt in ipairs( points ) do
						ref_points[#ref_points+1] = pt
					end
				end
				objects[#objects+1] = Polygon2D( reference, points )
			end

			local fill = p.create_fill( node, reference )
			if fill ~= nil then
				fills[#fills+1] = fill
			end
		end

		if view_box == nil then
			view_box = self:calculate_view_box( ref_points, 5.0 )
		end

		-- Svg(view_box, window, background_color, fills, objects)
		return Svg( view_box, window, nil, fills, objects )
	end

	return obj
end


return p