Pythonによるレイトレーシング

プリミティブは球体のみ、シェーディングはフラットシェーディング、カメラは正投影、解像度固定。

import math
from itertools import product
from functools import partial


def range2(rx, ry):
    """Iterate over every (x, y) index pair with 0 <= x < rx and 0 <= y < ry.

    ``x`` is the outer (slowly varying) index and ``y`` the inner one, i.e.
    pairs come out as (0, 0), (0, 1), ..., (1, 0), ...
    """
    xs, ys = range(rx), range(ry)
    return product(xs, ys)


class Vector(object):
    """Mutable 3-D vector with components ``x``, ``y``, ``z``.

    Note that ``a * b`` between two vectors is the dot product (a scalar),
    not a component-wise product.
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def normalize(self):
        """Scale this vector in place to unit length.

        Raises:
            ValueError: if the vector has zero length and so has no direction.
        """
        length = math.sqrt(self.x*self.x + self.y*self.y + self.z*self.z)
        if length == 0.0:
            raise ValueError('cannot normalize a zero-length vector')
        self.x /= length
        self.y /= length
        self.z /= length

    def __mul__(self, v):
        """Return the dot product of this vector and ``v``."""
        return self.x*v.x + self.y*v.y + self.z*v.z

    def __sub__(self, v):
        """Return the component-wise difference as a new Vector."""
        return Vector(self.x - v.x, self.y - v.y, self.z - v.z)

    def printval(self):
        """Print the three components separated by spaces on one line."""
        # Single-argument print() is valid on Python 2 and 3 and produces the
        # same text as the Python 2 statement ``print x, y, z``.
        print('%s %s %s' % (self.x, self.y, self.z))


class Color(Vector):
    """RGB colour stored in the Vector's x/y/z slots as r/g/b.

    Channels are expected to lie in [0, 1] (all colours built in this file do).
    """

    def __init__(self, r, g, b):
        super(Color, self).__init__(r, g, b)

    def printval(self):
        """Print the colour as one PPM P3 sample line "R G B" (0-255 each).

        Each channel is scaled by 255 and truncated toward zero. It is then
        clamped to [0, 255] so an out-of-range channel cannot write a sample
        above the file's declared maxval of 255.
        """
        r, g, b = [max(0, min(255, int(c * 255)))
                   for c in (self.x, self.y, self.z)]
        # Single-argument print() behaves the same on Python 2 and 3.
        print('%d %d %d' % (r, g, b))


class Point(Vector):
    """A position in 3-D space; inherits all of Vector's arithmetic."""

    def __init__(self, x, y, z):
        Vector.__init__(self, x, y, z)


class Ray(object):
    """Half-line with origin ``p`` and unit-length direction ``v``."""

    def __init__(self, p, v):
        self.p = p
        # Normalize a private copy rather than the caller's object. Normalizing
        # in place and storing the same object means rays built from one shared
        # direction vector would all alias it, and any later change to one
        # ray's direction would silently change them all.
        self.v = Vector(v.x, v.y, v.z)
        self.v.normalize()


class Primitive(object):
    """Base class for scene objects a ray can be tested against.

    Subclasses override ``intersect`` to return ``(obj, t)``, where ``t`` is
    the hit distance or None, and ``color`` to return the object's colour.
    """

    def __init__(self):
        """No state to set up."""

    def intersect(self, ray):
        """Default: no object and no hit distance."""
        return None, None

    def color(self):
        """Default: no colour."""
        return None


class Back(Primitive):
    """Background: every ray hits it, at infinite distance, in mid-grey."""

    def __init__(self):
        Primitive.__init__(self)

    def intersect(self, ray):
        # An infinite distance means any real object in front of it wins.
        distance = float('inf')
        return self, distance

    def color(self):
        return Color(0.5, 0.5, 0.5)


class Sphere(Primitive):
    """Flat-coloured sphere with centre ``cp``, radius ``rad`` and colour ``col``."""

    def __init__(self, p, rad, col):
        self.cp, self.rad, self.col = p, rad, col

    def _pmin(self, a, b):
        """Pick the nearest root not behind the ray origin.

        ``a`` and ``b`` are the two roots with a <= b. Returns None when both
        are negative, ``b`` when only ``a`` is (origin inside the sphere),
        otherwise ``a``.
        """
        if b < 0.0:
            return None
        return b if a < 0.0 else a

    def intersect(self, ray):
        """Return ``(self, t)``, where t is the hit distance or None on a miss.

        Solves |o + t*d - c|^2 = r^2. This assumes ``ray.v`` has unit length
        (Ray normalizes it), which reduces the quadratic to
        t^2 + 2*h*t + (oc.oc - r^2) = 0 with h = oc.d.
        """
        oc = ray.p - self.cp
        half_b = oc * ray.v
        disc = half_b * half_b - oc * oc + self.rad * self.rad
        if disc < 0.0:
            return self, None
        root = math.sqrt(disc)
        return self, self._pmin(-half_b - root, -half_b + root)

    def color(self):
        return self.col


def create_ray(rx=200, ry=200):
    """Build one orthographic primary ray per pixel, in PPM scan order.

    The image plane covers [-1, 1] x [-1, 1] at z = 2 and every ray points
    along -z. Rays are returned row by row: top row first, each row from left
    to right. This is the raster order of a P3 PPM file, so world +x appears
    to the right and world +y upward in the written image. Each ray samples
    the centre of its pixel.

    Args:
        rx: horizontal resolution (pixels per row); defaults to 200.
        ry: vertical resolution (number of rows); defaults to 200.

    Returns:
        A list of ``rx * ry`` Ray objects.
    """
    def centre(i, n):
        # Centre of cell i when (-1, 1) is split into n equal cells.
        return -1.0 + 2.0 * (i + 0.5) / n

    # range2(ry, rx) yields (row, column) with the column varying fastest,
    # which is row-major order. Screen rows grow downward, so y is negated.
    # Every ray gets its own direction vector so no two rays share one
    # mutable object.
    return [Ray(Point(centre(x, rx), -centre(y, ry), 2.0),
                Vector(0.0, 0.0, -1.0))
            for y, x in range2(ry, rx)]


def get_objs():
    """Build the fixed scene: a red and a green sphere plus the background.

    A new list of new objects is created on every call.
    """
    red = Sphere(Point(0.0, 0.0, 0.0), 0.5, Color(1.0, 0.0, 0.0))
    green = Sphere(Point(0.5, 0.5, -1.0), 0.5, Color(0.0, 1.0, 0.0))
    background = Back()
    return [red, green, background]


def trace(ray, objs=None):
    """Return the colour of the nearest object hit by ``ray``.

    Args:
        ray: the ray to trace.
        objs: scene objects to test. Each must provide ``intersect(ray)``,
            returning ``(obj, t)`` with t None on a miss, and the returned obj
            must provide ``color()``. Defaults to a freshly built
            ``get_objs()`` scene. Pass a prebuilt list to avoid rebuilding the
            scene for every ray.

    Returns:
        ``color()`` of the object with the smallest hit distance. On a tie,
        the earliest object in ``objs`` wins.

    Raises:
        ValueError: if no object is hit (e.g. the scene lacks a background).
    """
    if objs is None:
        objs = get_objs()
    # Compare against None explicitly: a hit at exactly t == 0.0 is valid,
    # and a truthiness test would wrongly discard it.
    hits = [hit for hit in (obj.intersect(ray) for obj in objs)
            if hit[1] is not None]
    if not hits:
        raise ValueError('ray hit no object; the scene needs a background')
    # min() returns the first minimal element, matching a stable sort + [0].
    nearest, _ = min(hits, key=lambda hit: hit[1])
    return nearest.color()


def output(cs, width=200, height=200):
    """Write pixels ``cs`` to stdout as a plain-text (P3) PPM image.

    Args:
        cs: iterable of objects with a ``printval()`` method that prints one
            "R G B" line; expected to hold ``width * height`` items in
            row-major order.
        width: image width written to the header (default 200).
        height: image height written to the header (default 200).
    """
    # Single-argument print() is valid and identical on Python 2 and 3.
    print('P3\n%d %d\n255' % (width, height))
    # An explicit loop, not map(): on Python 3 map() is lazy and a discarded
    # map object would never call printval() at all.
    for c in cs:
        c.printval()


def main():
    """Render the fixed scene and write it to stdout as a PPM image."""
    rs = create_ray()
    cs = [trace(r) for r in rs]
    output(cs)


# Render only when run as a script, so importing this module has no side effects.
if __name__ == '__main__':
    main()

index
Mail